"""General plotting functions for sctoolbox, e.g. general plots for wrappers, and saving and adding titles to figures."""
import pandas as pd
import numpy as np
import warnings
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib_venn import venn2, venn3
import scipy.cluster.hierarchy as sciclust
import seaborn
from beartype import beartype
from beartype.typing import Iterable, Optional, Literal, Tuple, Union, Any
from sctoolbox import settings
########################################################################################
# -------------------- General helper functions for plotting ------------------------- #
########################################################################################
@beartype
def _save_figure(path: Optional[str],
                 dpi: int = 600,
                 **kwargs: Any) -> None:
    """Write the currently active matplotlib figure to disk.

    Parameters
    ----------
    path : Optional[str]
        Target file name. It is prefixed with the internal figure prefix
        ('sctoolbox.settings.figure_dir' + 'sctoolbox.settings.figure_prefix').
        Add the extension (e.g. .tiff, as in /some/path/plot.tiff) to choose the
        output format; without an extension the figure is saved as .png.
        If None, nothing is saved.
    dpi : int, default 600
        Resolution in dots per inch. Higher values increase resolution.
    **kwargs : Any
        Extra keyword arguments for matplotlib.pyplot.savefig. They take precedence
        over the defaults bbox_inches="tight" and facecolor="white".
    """
    # Plotting functions forward their own `save` argument unconditionally, so
    # None simply means "saving was not requested".
    if path is None:
        return

    # Defaults first, caller-supplied options override them.
    savefig_options = {"bbox_inches": "tight", "facecolor": "white", **kwargs}
    plt.savefig(settings.full_figure_prefix + path, dpi=dpi, **savefig_options)
@beartype
def _make_square(ax: matplotlib.axes.Axes) -> None:
    """Force a plot to be square using aspect ratio regardless of the x/y ranges.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose aspect ratio is adjusted in place.
    """
    # Use absolute spans: on inverted axes get_xlim/get_ylim return (high, low),
    # which would otherwise give a negative aspect that matplotlib rejects.
    x_span = abs(np.diff(ax.get_xlim())[0])
    y_span = abs(np.diff(ax.get_ylim())[0])

    # aspect = (display height per y-unit) / (display width per x-unit);
    # x_span / y_span makes both data ranges occupy the same display length.
    ax.set_aspect(x_span / y_span)
@beartype
def _add_figure_title(axarr: Iterable[matplotlib.axes.Axes] | matplotlib.axes.Axes | seaborn.matrix.ClusterGrid,
                      title: str,
                      y: float | int = 1.3,
                      fontsize: int = 16) -> None:
    """Add a figure title to the top of a multi-axes figure.

    Parameters
    ----------
    axarr : Iterable[matplotlib.axes.Axes] | matplotlib.axes.Axes | seaborn.matrix.ClusterGrid
        Axes to add the title to. The following are accepted:

        - a single Axes;
        - a list or (possibly multi-dimensional) numpy array of Axes;
        - a dict with Axes as values (e.g. scanpy dotplot);
        - an object holding Axes as attributes (e.g. seaborn clustermap).
    title : str
        Title to add at the top of plot.
    y : float | int, default 1.3
        Vertical position of the title in relation to the content. The highest point
        of all axes (in data coordinates of the first axes) is multiplied by this value.
        Larger number moves the title further up.
    fontsize : int, default 16
        Font size of the title.

    Raises
    ------
    ValueError
        If no axes could be found in `axarr`.

    Examples
    --------
    .. plot::
        :context: close-figs

        axes = sc.pl.umap(adata, color=["louvain", "condition"], show=False)
        pl.add_figure_title(axes, "UMAP plots", fontsize=20)
    """
    # Normalize the input to a flat list of Axes
    if type(axarr).__name__.startswith("Axes"):
        axes = [axarr]
    elif isinstance(axarr, np.ndarray):
        # e.g. the 2D array returned by plt.subplots(nrows, ncols): flatten so each
        # element is a single Axes instead of a row of Axes
        axes = list(axarr.flatten())
    else:
        try:
            axarr[0]
        except (TypeError, KeyError, IndexError):
            # Not indexable by position: collect Axes from a dict (e.g. scanpy dotplot)
            # or from the attributes of an object (e.g. seaborn clustermap)
            container = axarr if isinstance(axarr, dict) else getattr(axarr, "__dict__", {})
            axes = [value for value in container.values() if type(value).__name__.startswith("Axes")]
        else:
            axes = list(axarr)

    if len(axes) == 0:
        raise ValueError("No axes found to add the figure title to.")

    # Draw the figure so the axes extents are final
    fig = plt.gcf()
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()

    # Get bounding boxes of all axes in data coordinates of the first axes
    trans_data_inv = axes[0].transData.inverted()  # from display to data
    bbox_list = [ax.get_window_extent(renderer=renderer).transformed(trans_data_inv) for ax in axes]

    # Title is centered horizontally above the highest point of all axes
    ty = np.max([bbox.y1 for bbox in bbox_list]) * y
    xmin = np.min([bbox.x0 for bbox in bbox_list])
    xmax = np.max([bbox.x1 for bbox in bbox_list])
    tx = np.mean([xmin, xmax])

    axes[0].text(tx, ty, title, va="bottom", ha="center", fontsize=fontsize)
@beartype
def _add_labels(data: pd.DataFrame,
                x: str,
                y: str,
                label_col: Optional[str] = None,
                ax: Optional[matplotlib.axes.Axes] = None,
                **kwargs: Any) -> list:
    """Add labels to a scatter plot.

    Parameters
    ----------
    data : pd.DataFrame
        Dataframe containing the coordinates of points to label.
    x : str
        Name of the column in data to use for x axis coordinates.
    y : str
        Name of the column in data to use for y axis coordinates.
    label_col : str, default None
        Name of the column in data to use for labels. If `None`, the index of data is used.
    ax : matplotlib.axes.Axes, default None
        Axis to plot on. If `None`, the current open figure axis is used.
    **kwargs : Any
        Additional arguments to pass to matplotlib.axes.Axes.annotate.

    Returns
    -------
    list
        List of matplotlib.text.Annotation objects.
    """
    if ax is None:
        ax = plt.gca()

    labels = data.index if label_col is None else data[label_col]

    # Walk the rows positionally. Indexing a Series with `[i]` is label-based,
    # which picks the wrong rows (or raises KeyError) when the dataframe does not
    # have a default 0..n-1 RangeIndex.
    texts = [ax.annotate(label, (x_coord, y_coord), **kwargs)
             for label, x_coord, y_coord in zip(labels, data[x], data[y])]

    # TODO: adjust text positions to avoid overlapping labels
    return texts
#############################################################################
# ------------------------------ Dotplot ---------------------------------- #
#############################################################################
[docs]
@beartype
def clustermap_dotplot(table: pd.DataFrame,
                       x: str,
                       y: str,
                       size: str,
                       hue: str,
                       cluster_on: Literal["hue", "size"] = "hue",
                       fillna: float | int = 0,
                       title: Optional[str] = None,
                       figsize: Optional[Tuple[int | float, int | float]] = None,
                       dend_height: float | int = 2,
                       dend_width: float | int = 2,
                       palette: str = "vlag",
                       x_rot: int = 45,
                       show_grid: bool = False,
                       save: Optional[str] = None,
                       **kwargs: Any) -> list:
    """
    Plot a heatmap with dots (instead of squares), which can contain the dimension of "size".

    Parameters
    ----------
    table : pd.DataFrame
        Table in long-format. Has to have at least four columns as given by x, y, size and hue.
    x : str
        Column in table to plot on the x-axis.
    y : str
        Column in table to plot on the y-axis.
    size : str
        Column in table to use for the size of the dots.
    hue : str
        Column in table to use for the color of the dots.
    cluster_on : Literal["hue", "size"], default hue
        Decide which values to use for creating the dendrograms. Either "hue" or "size".
    fillna : float | int, default 0
        Replace NaN with given value.
    title : Optional[str], default None
        Title of the dotplot.
    figsize : Optional[Tuple[int | float, int | float]], default None
        Figure size in inches. Default is estimated from the number of rows/columns (ncols/3, nrows/3).
    dend_height : float | int, default 2
        Height of the x-axis dendrogram in counts of row elements, e.g. 2 represents a height of 2 rows in the dotplot.
    dend_width : float | int, default 2
        Width of the y-axis dendrogram in counts of column elements, e.g. 2 represents a width of 2 columns in the dotplot.
    palette : str, default vlag
        Color palette for hue colors.
    x_rot : int, default 45
        Rotation of xticklabels in degrees.
    show_grid : bool, default False
        Show grid behind dots in plot.
    save : Optional[str], default None
        Save the figure to this path.
    **kwargs : Any
        Additional arguments to pass to seaborn.scatterplot. These override the
        defaults `sizes=(10, 200)` and `zorder=100`.

    Returns
    -------
    list
        List of matplotlib.axes.Axes objects containing the dotplot and the dendrogram(s).

    Examples
    --------
    .. plot::
        :context: close-figs

        table = adata.obs.reset_index()[:10]

    .. plot::
        :context: close-figs

        pl.clustermap_dotplot(
            table=table,
            x="bulk_labels",
            y="index",
            hue="n_genes",
            size="n_counts",
            palette="viridis"
        )
    """
    table = table.copy()

    # long table to wide format for hue and size
    wide_hue = pd.pivot(data=table, index=y, columns=x, values=hue).fillna(fillna)
    wide_size = pd.pivot(data=table, index=y, columns=x, values=size).fillna(fillna)
    nrows, ncols = wide_hue.shape  # same shape as wide_size

    # matrix the dendrograms are computed from
    wide_cluster = wide_hue if cluster_on == "hue" else wide_size

    # a dendrogram needs at least two leaves
    x_dend_possible = len(wide_cluster.columns) > 1
    y_dend_possible = len(wide_cluster.index) > 1

    # Set figsize automatically
    if figsize is None:
        figsize = (ncols / 3, nrows / 3)

    # Create figure
    _, ax = plt.subplots(1, figsize=figsize)
    axes = [ax]

    # Prepare shape of dotplot
    ax.set_xlim(-0.5, ncols - 0.5)
    ax.set_ylim(-0.5, nrows - 0.5)
    ax.set_xticks(np.arange(ncols))
    ax.set_aspect(1)

    # x-axis (column) dendrogram, placed directly above the dotplot
    if x_dend_possible:
        x_link = sciclust.linkage(wide_cluster.T)
        col_dend_ax = ax.inset_axes([0, 1, 1, dend_height / nrows])
        axes.append(col_dend_ax)
        x_dend = sciclust.dendrogram(x_link,
                                     orientation="top",
                                     labels=wide_cluster.columns,
                                     no_labels=True,
                                     link_color_func=lambda _: "black",  # disable cluster colors
                                     ax=col_dend_ax)
        col_dend_ax.axis("off")

        # order after dendrogram via an ordered categorical (sharey parameter is bugged)
        # https://towardsdatascience.com/how-to-do-a-custom-sort-on-pandas-dataframe-ac18e7ea5320
        # NOTE(review): scipy lists `ivl` left-to-right for a "top" dendrogram, yet the
        # column order is reversed here (mirroring the y-axis handling below, where seaborn
        # presumably draws the first category at the top). Verify that the dotplot columns
        # line up with the dendrogram leaves.
        x_order = pd.CategoricalDtype(reversed(x_dend["ivl"]), ordered=True)
        table[x] = table[x].astype(x_order)

    # y-axis (row) dendrogram, placed directly right of the dotplot
    if y_dend_possible:
        y_link = sciclust.linkage(wide_cluster)
        row_dend_ax = ax.inset_axes([1, 0, dend_width / ncols, 1])
        axes.append(row_dend_ax)
        y_dend = sciclust.dendrogram(y_link,
                                     orientation="right",
                                     labels=wide_cluster.index,
                                     no_labels=True,
                                     link_color_func=lambda _: "black",  # disable cluster colors
                                     ax=row_dend_ax)
        row_dend_ax.axis("off")

        # order after dendrogram via an ordered categorical (sharey parameter is bugged)
        # https://towardsdatascience.com/how-to-do-a-custom-sort-on-pandas-dataframe-ac18e7ea5320
        y_order = pd.CategoricalDtype(reversed(y_dend["ivl"]), ordered=True)
        table[y] = table[y].astype(y_order)

    # sort matrix according to the dendrogram orders
    table = table.sort_values([x, y])

    # Defaults for the dots; user kwargs may override them instead of raising
    # "got multiple values for keyword argument".
    scatter_kwargs = {"sizes": (10, 200),
                      "zorder": 100}  # place points above grid
    scatter_kwargs.update(kwargs)

    # Fill in axes with dotplot
    plot = sns.scatterplot(data=table,
                           y=y,
                           x=x,
                           size=size,
                           hue=hue,
                           palette=palette,
                           ax=ax,
                           **scatter_kwargs)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=x_rot, ha="right" if x_rot != 0 else "center")

    # Move legend to right side (next to the row dendrogram if one was drawn)
    x_anchor = 1 + dend_width / ncols if y_dend_possible else 1
    sns.move_legend(plot, loc='upper left', bbox_to_anchor=(x_anchor, 1, 0, 0))

    # Show gridlines
    if show_grid:
        ax.grid()

    # Title above plot (above the column dendrogram if one was drawn)
    title_ax = col_dend_ax if x_dend_possible else ax
    title_ax.set_title(title)

    # Save figure
    _save_figure(save)

    return axes
########################################################################################
# ------------------------------------- Barplot -------------------------------------- #
########################################################################################
[docs]
@beartype
def bidirectional_barplot(df: pd.DataFrame,
                          title: Optional[str] = None,
                          colors: Optional[dict[str, str]] = None,
                          figsize: Optional[Tuple[int | float, int | float]] = None,
                          save: Optional[str] = None) -> matplotlib.axes.Axes:
    """Plot a bidirectional barplot.

    A vertical barplot where each position has one bar going left and one going right (bidirectional).

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with the following mandatory column names:

        - left_label
        - right_label
        - left_value
        - right_value
    title : Optional[str], default None
        Title of the plot.
    colors : Optional[dict[str, str]], default None
        Dictionary with label names as keys and colors as values. If None, colors of
        the current seaborn palette are assigned to the labels in order of first
        appearance (cycling the palette if there are more labels than colors).
    figsize : Optional[Tuple[int | float, int | float]], default None
        Figure size. Default is 5 wide and one unit of height per row in df.
    save : Optional[str], default None
        If given, the figure will be saved to this path.

    Returns
    -------
    matplotlib.axes.Axes
        Axes containing the plot.

    Raises
    ------
    KeyError
        If df does not contain the required columns.
    """
    # Check that df contains columns left/right_label and left/right value
    required_columns = ["left_label", "right_label", "left_value", "right_value"]
    for col in required_columns:
        if col not in df.columns:
            raise KeyError(f"Column {col} not found in dataframe.")

    labels_left = df["left_label"].tolist()
    labels_right = df["right_label"].tolist()
    # Plain numpy arrays: values are accessed by position below. Indexing the
    # Series with [i] would be label-based and break for a non-default index.
    values_left = -np.abs(df["left_value"].to_numpy())  # left bars extend to negative x
    values_right = df["right_value"].to_numpy()

    if colors is None:
        # dict.fromkeys keeps first-appearance order, so the color assignment is
        # reproducible (iteration order of a set of strings varies between runs)
        all_labels = list(dict.fromkeys(labels_left + labels_right))
        # request exactly as many colors as labels; seaborn cycles the palette if needed
        colors = dict(zip(all_labels, sns.color_palette(n_colors=len(all_labels))))

    # Create figure and axis objects
    if figsize is None:
        figsize = (5, len(labels_left))  # 5 wide, n bars tall
    _, ax = plt.subplots(figsize=figsize)

    # First row of df at the top
    n_bars = len(labels_left)
    yticks = np.arange(n_bars)[::-1]

    # Plot the right-hand (positive) bars
    right_colors = [colors[label] for label in labels_right]
    right_bars = ax.barh(yticks, values_right, color=right_colors)

    # Plot the left-hand (negative) bars
    left_colors = [colors[label] for label in labels_left]
    left_bars = ax.barh(yticks, values_left, color=left_colors)

    # Set the x-axis limits to include both positive and negative values
    ax.set_xlim([min(values_left) * 1.1, max(values_right) * 1.1])

    # Add a vertical line at x=0 to indicate the zero point
    ax.axvline(x=0, color='k')

    # Add text labels (outside the bar end) and values (bar center) to right bars
    for bar, label, value in zip(right_bars, labels_right, values_right):
        y_center = bar.get_y() + bar.get_height() / 2
        # leading space ensures space between bar and label
        ax.text(bar.get_width(), y_center, " " + str(label), ha='left', va='center')
        ax.text(bar.get_width() / 2, y_center, str(value), ha='center', va='center')

    # Add text labels and values to left bars (bar widths are negative here)
    for bar, label, value in zip(left_bars, labels_left, values_left):
        y_center = bar.get_y() + bar.get_height() / 2
        ax.text(bar.get_width(), y_center, str(label) + " ", ha='right', va='center')
        ax.text(bar.get_width() / 2, y_center, str(np.abs(value)), ha='center', va='center')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_position(('data', 0))
    ax.set_yticks([])
    ax.set_yticklabels([])

    # Show the x-axis tick labels as positive numbers on both sides. Integral
    # ticks are printed without decimals; fractional ticks (e.g. 0.5) keep their
    # value instead of being truncated to 0.
    ticks = ax.get_xticks().tolist()
    ax.set_xticks(ticks)  # prevent "FixedFormatter should only be used together with FixedLocator"
    ax.set_xticklabels([str(int(abs(tick))) if float(tick).is_integer() else f"{abs(tick):g}"
                        for tick in ticks])

    if title is not None:
        ax.set_title(title)

    # Save figure
    _save_figure(save)

    return ax
########################################################################################
# ----------------------------- Boxplot / violinplot -------------------------------- #
########################################################################################
[docs]
@beartype
def boxplot(dt: pd.DataFrame,
            show_median: bool = True,
            ax: Optional[matplotlib.axes.Axes] = None,
            **kwargs: Any) -> matplotlib.axes.Axes:
    """Generate one plot containing one box per column. The median value is shown.

    Parameters
    ----------
    dt : pd.DataFrame
        pandas datafame containing numerical values in every column.
    show_median : boolean, default True
        If True show median value as small box inside the boxplot.
    ax : Optional[matplotlib.axes.Axes], default None
        Axes object to plot on. If None, a new figure is created.
    **kwargs : Any
        Additional arguments to pass to seaborn.boxplot.

    Returns
    -------
    matplotlib.axes.Axes
        containing boxplot for every column.

    Examples
    --------
    .. plot::
        :context: close-figs

        import pandas as pd
        dt = pd.DataFrame(np.random.randint(0,100,size=(100, 4)), columns=list('ABCD'))

    .. plot::
        :context: close-figs

        pl.boxplot(dt, show_median=True, ax=None)
    """
    # `ax` is already type-checked by beartype; only create one if missing
    if ax is None:
        _, ax = plt.subplots()

    # The filtered "iteritems" FutureWarning is emitted by pandas while seaborn
    # processes the data, so plotting happens inside the filter context.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning, message="iteritems is deprecated*")
        dt_melt = dt.melt()
        ax = sns.boxplot(data=dt_melt, x="variable", y="value", ax=ax, **kwargs)

    ax.set_xlabel("")
    ax.set_ylabel("")

    if show_median:
        # From:
        # https://stackoverflow.com/questions/49554139/boxplot-of-multiple-columns-of-a-pandas-dataframe-on-the-same-figure-seaborn
        lines = ax.get_lines()
        categories = ax.get_xticks()

        if len(categories) > 0:
            # matplotlib draws per box: 2 whiskers, 2 caps, the median, then optional
            # mean/flier markers. Derive the number of lines per box instead of
            # assuming 6, so e.g. showfliers=False (5 lines per box) still picks the
            # median line (index 4 within each box; assumes caps are drawn).
            lines_per_box = len(lines) // len(categories)

            # Add median label
            for cat in categories:
                y = round(lines[4 + int(cat) * lines_per_box].get_ydata()[0], 2)
                ax.text(cat, y, f'{y}', ha='center', va='center', fontweight='bold', size=10, color='white',
                        bbox=dict(facecolor='#445A64'))

    return ax
[docs]
@beartype
def violinplot(table: pd.DataFrame,
               y: str,
               color_by: Optional[str] = None,
               hlines: Optional[Union[float | int,
                                      list[float | int],
                                      dict[str, Union[float | int, list[float | int]]]]] = None,
               colors: Optional[list[str]] = None,
               ax: Optional[matplotlib.axes.Axes] = None,
               title: Optional[str] = None,
               ylabel: bool | str = True,
               **kwargs: Any) -> matplotlib.axes.Axes:
    """Plot a violinplot with optional horizontal lines for each violin.

    Parameters
    ----------
    table : pd.DataFrame
        Values to create the violins from.
    y : str
        Column name of table. Values that will be shown on y-axis.
    color_by : Optional[str], default None
        Column name of table. Used to group the violins (one violin per value).
        Missing values are ignored. The violins follow the category order for
        categorical columns; numeric columns are sorted; other columns keep their
        order of first appearance.
    hlines : Optional[Union[float | int, list[float | int],
             dict[str, Union[float | int, list[float | int]]]]], default None
        Define horizontal lines for each violin. A number or list of numbers draws
        the line(s) on every violin. A dict (only with color_by) maps group names of
        color_by to the number(s) for that violin.
    colors : Optional[list[str]], default None
        List of colors to use for violins (one per group). If color_by is None, the
        first color is used for the single violin.
    ax : Optional[matplotlib.axes.Axes], default None
        Axes object to draw the plot on. Otherwise use current axes.
    title : Optional[str], default None
        Title of the plot.
    ylabel : bool | str, default True
        Boolean if ylabel should be shown. Or str for custom ylabel.
    **kwargs : Any
        Additional arguments to pass to seaborn.violinplot.

    Returns
    -------
    matplotlib.axes.Axes
        Object containing the violinplot.

    Raises
    ------
    ValueError
        If y or color_by is not a column name of table, if hlines is not a number or
        list of numbers for color_by=None, or if a key of a hlines dict is not a
        group of color_by.

    Examples
    --------
    .. plot::
        :context: close-figs

        import seaborn as sns
        table = sns.load_dataset("titanic")

    .. plot::
        :context: close-figs

        pl.violinplot(table, "age", color_by="class", hlines=None, colors=None, ax=None, title=None, ylabel=True)
    """
    # check if valid column name
    if y not in table.columns:
        raise ValueError(f"{y} not found in column names of table! Use one of {list(table.columns)}.")

    # check if color_by is valid column name
    if color_by is not None and color_by not in table.columns:
        raise ValueError(f"Color grouping '{color_by}' not found in column names of table! Use one of {list(table.columns)}")

    # set a deterministic violin order without missing values
    # (a set would give a run-dependent order and keep NaN groups)
    if color_by is not None:
        groups = table[color_by].dropna()
        if isinstance(groups.dtype, pd.CategoricalDtype):
            # keep the user-defined category order, restricted to observed categories
            color_group_order = list(groups.cat.remove_unused_categories().cat.categories)
        else:
            color_group_order = list(pd.unique(groups))  # order of first appearance
            if pd.api.types.is_numeric_dtype(groups):
                color_group_order = sorted(color_group_order)
    else:
        color_group_order = None

    # hlines has to be number or list if color_by=None
    if hlines is not None and color_by is None and not isinstance(hlines, (list, tuple, int, float)):
        raise ValueError(f"Parameter hlines has to be number or list of numbers for color_by=None. Got type {type(hlines)}.")

    # lines are looked up per violin, so dict keys have to be groups of color_by
    # (a dict with color_by=None was already rejected above)
    if isinstance(hlines, dict):
        invalid_keys = set(hlines.keys()) - set(color_group_order)
        if invalid_keys:
            raise ValueError(f"Invalid dict keys in hlines parameter. Key(s) have to match values of column '{color_by}'. Invalid keys: {invalid_keys}")

    # seaborn's `color` takes a single color; a list of colors is a palette
    color_kwargs = {}
    if colors:
        if color_by is None:
            color_kwargs["color"] = colors[0]  # single violin
        else:
            color_kwargs["palette"] = colors

    # create violinplot
    plot = sns.violinplot(data=table, y=y, x=color_by, order=color_group_order, ax=ax, **color_kwargs, **kwargs)

    # add horizontal lines (explicit None check so that hlines=0 is drawn)
    if hlines is not None:
        violin_names = color_group_order if color_by is not None else [None]

        if violin_names:
            # horizontal line length in axes coordinates; violins are evenly spaced
            violin_width = 1 / len(violin_names)
            line_length = violin_width - 2 * violin_width * 0.1  # subtract 10% padding
            half_length = line_length / 2

            # draw line(s) for each violin
            for i, violin_name in enumerate(violin_names):
                violin_center = violin_width * (i + 1) - violin_width / 2

                # per-violin heights from a dict (without mutating the caller's dict),
                # otherwise the same heights for every violin
                line_heights = hlines.get(violin_name, []) if isinstance(hlines, dict) else hlines
                if not isinstance(line_heights, (list, tuple)):
                    line_heights = [line_heights]

                for line_height in line_heights:
                    # skip if invalid line_height
                    if not isinstance(line_height, (int, float)):
                        continue
                    plot.axhline(y=line_height,
                                 xmin=violin_center - half_length,
                                 xmax=violin_center + half_length,
                                 color="orange",
                                 ls="dashed",
                                 lw=3)

    # add title
    if title:
        plot.set(title=title)

    # adjust y-label
    if not ylabel:
        plot.set(ylabel=None)
    elif isinstance(ylabel, str):
        plot.set(ylabel=ylabel)

    # remove x-axis ticks if color_by=None
    if color_by is None:
        plot.tick_params(axis="x", which="both", bottom=False)

    return plot
########################################################################################
# -------------------------------- Venn diagrams ------------------------------------ #
########################################################################################
[docs]
@beartype
def plot_venn(groups_dict: dict[str, list[Any]],
              title: Optional[str] = None,
              save: Optional[str] = None,
              **kwargs: Any) -> None:
    """Plot a Venn diagram from a dictionary of 2-3 groups of lists.

    Parameters
    ----------
    groups_dict : dict[str, list[Any]]
        A dictionary where the keys are group names (strings) and the values
        are lists of items belonging to that group (e.g. {'Group A': ['A', 'B', 'C'], ...}).
    title : Optional[str], default None
        Title of the plot.
    save : Optional[str], default None
        Filename to save the plot to.
    **kwargs : Any
        Additional arguments to pass to matplotlib_venn.venn2 or matplotlib_venn.venn3.

    Raises
    ------
    ValueError
        If number of groups in groups_dict is not 2 or 3.

    Examples
    --------
    .. plot::
        :context: close-figs

        venn2_example = { 'Group A': [1, 2, 3, 4],
                          'Group B': [3, 4, 5, 6]
                        }

    .. plot::
        :context: close-figs

        pl.plot_venn(venn2_example, "Simple Venn2 plot")

    .. plot::
        :context: close-figs

        venn3_example = { 'Fruits A': ['Lemon', 'Orange', 'Blueberry', 'Grapefruit'],
                          'Fruits B': ['Pineapple', 'Mango', 'Banana', 'Papaya', 'Blueberry', 'Strawberry'],
                          'Fruits C': ['Strawberry', 'Blueberry', 'Raspberry', 'Orange', 'Mango']
                        }

    .. plot::
        :context: close-figs

        pl.plot_venn(venn3_example, "Simple Venn3 plot")
    """
    # Validate before creating a figure, so an invalid call does not leave an
    # empty figure open.
    if len(groups_dict) not in (2, 3):
        raise ValueError("Only 2 or 3 groups are supported.")

    # Convert the lists of items to sets; labels follow the dict order
    group_sets = [set(items) for items in groups_dict.values()]
    set_labels = list(groups_dict.keys())

    plt.figure()

    # Plot the Venn diagram using matplotlib_venn
    draw_venn = venn2 if len(group_sets) == 2 else venn3
    draw_venn(group_sets, set_labels=set_labels, **kwargs)

    if title is not None:
        plt.title(title)

    # Save the plot (no-op if save is None)
    _save_figure(save)
########################################################################################
# -------------------------------- Scatter plots ------------------------------------- #
########################################################################################
[docs]
@beartype
def pairwise_scatter(table: pd.DataFrame,
                     columns: list[str],
                     thresholds: Optional[dict[str, dict[Literal["min", "max"], int | float]]] = None,
                     save: Optional[str] = None,
                     **kwargs: Any) -> np.ndarray:
    """Plot a grid of scatterplot comparing column values pairwise.

    If thresholds are given, lines are drawn for each threshold and points outside of the thresholds are colored red.

    Parameters
    ----------
    table : pd.DataFrame
        Dataframe containing the data to plot.
    columns : list[str]
        List of column names in table to plot.
    thresholds : Optional[dict[str, dict[Literal["min", "max"], int | float]]], default None
        Dictionary containing thresholds for each column. Keys are column names and values are dictionaries with keys "min" and "max".
    save : Optional[str], default None
        If given, the figure will be saved to this path.
    **kwargs : Any
        Additional arguments to pass to matplotlib.axes.Axes.scatter. The marker
        size defaults to `s=1` and can be overridden here.

    Returns
    -------
    np.ndarray
        Array of matplotlib.axes.Axes objects.

    Raises
    ------
    ValueError
        1. If columns contains less than two columns.
        2. If one of the given columns is not a table column

    Examples
    --------
    .. plot::
        :context: close-figs

        columns = ["percent_mito", "n_counts", "S_score"]
        thresholds = {"n_counts": {"min": 2500, "max": 8000},
                      "percent_mito": {"max": 0.03},
                      "S_score": {"max": 0.5}}
        pl.pairwise_scatter(adata.obs, columns, thresholds=thresholds)
    """
    if len(columns) < 2:
        raise ValueError("'columns' must contain at least two columns to compare.")

    for col in columns:
        if col not in table.columns:
            raise ValueError(f"Column '{col}' not found in table.")

    if thresholds is None:
        thresholds = {}

    # Default marker size; user kwargs (including `s`) take precedence instead of
    # raising "got multiple values for keyword argument 's'".
    scatter_kwargs = {"s": 1}
    scatter_kwargs.update(kwargs)

    # Initialize plot
    n_columns = len(columns)
    _, axarr = plt.subplots(nrows=n_columns, ncols=n_columns,
                            figsize=(n_columns * 3, n_columns * 3))

    # Fill in plots: rows hold the y variable, columns the x variable
    excluded_flag = False
    for i_row, c_row in enumerate(columns):
        for i_col, c_col in enumerate(columns):
            ax = axarr[i_row, i_col]

            if i_row == i_col:  # plot histogram on the diagonal
                sns.histplot(table[c_col], ax=ax, color="black")
                ax.set_xlabel("")  # labels are set afterwards
                ax.set_ylabel("")  # labels are set afterwards
                continue

            # Points violating a threshold of either plotted column are colored red
            included = np.ones(len(table), dtype=bool)
            for col in [c_col, c_row]:
                if col in thresholds:
                    lower = thresholds[col].get("min", table[col].min())
                    upper = thresholds[col].get("max", table[col].max())
                    included = included & (table[col] >= lower) & (table[col] <= upper)

            colors = np.where(included, "black", "red")
            ax.scatter(table[c_col], table[c_row], c=colors, **scatter_kwargs)  # x=columns, y=rows
            excluded_flag = excluded_flag or not np.all(included)  # set flag if any points are excluded

    # Plot threshold lines
    for i, col in enumerate(columns):
        if col not in thresholds:
            continue
        for key in ["min", "max"]:
            if key not in thresholds[col]:
                continue

            # vertical lines in column i, where this variable is on the x-axis
            for ax in axarr[:, i]:
                ax.axvline(thresholds[col][key], color="darkgrey", lw=1, linestyle="--")

            # horizontal lines in the scatterplots of row i (y-axis); skip the histogram
            scatter_idx = [i_col for i_col in range(n_columns) if i_col != i]
            for ax in axarr[i, scatter_idx]:
                ax.axhline(thresholds[col][key], color="darkgrey", lw=1, linestyle="--")

    # The first histogram's y-axis shows counts; replace its ticks by a twin axis
    # matching the y-range of the scatterplots in the first row
    ax = axarr[0, 0].twinx()
    ax.set_ylim(axarr[0, 1].get_ylim())
    ax.yaxis.set_label_position('left')
    ax.yaxis.set_ticks_position('left')
    axarr[0, 0].set_yticks([])  # remove original axis
    axarr[0, 0] = ax

    # Set labels
    for i, col in enumerate(columns):
        axarr[i, 0].set_ylabel(col)  # left column contains y labels
        axarr[-1, i].set_xlabel(col)  # bottom row contains x labels

    # Remove ticklabels from inner plots
    for ax in axarr[:, 1:].flatten():  # all but first column
        ax.axes.yaxis.set_ticklabels([])
    for ax in axarr[:-1, :].flatten():  # all but last row
        ax.axes.xaxis.set_ticklabels([])

    # Add legend if any points are excluded
    if excluded_flag:
        point = Line2D([0], [0], marker='o', markersize=np.sqrt(20), color='r', linestyle='None')
        axarr[0, -1].legend([point], ["Excluded"], loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)

    plt.subplots_adjust(wspace=0.08, hspace=0.08)

    # Save plot
    _save_figure(save)

    return axarr