"""Functions for plotting QC-related figures e.g. number of cells per group and violins."""
import pandas as pd
import copy
import numpy as np
import ipywidgets
import traitlets
import functools # for partial functions
import glob
import scanpy as sc
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import sctoolbox.utils as utils
from sctoolbox.plotting.general import _save_figure
import sctoolbox.utils.decorator as deco
# type hint imports
from beartype.typing import Tuple, Dict, Optional, Literal, Callable, Iterable, Any # , Union, List
from beartype import beartype
########################################################################################
# ------------------------------ QC plots for starsolo ------------------------------- #
########################################################################################
@beartype
def _read_starsolo_summary(folder: str) -> pd.DataFrame:
    """Collect the 'Summary.csv' of every STARsolo run below `folder` in one table.

    Parameters
    ----------
    folder : str
        Path to a folder, e.g. "path/to/starsolo_output", which contains folders "solorun1", "solorun2", etc.

    Raises
    ------
    ValueError
        If no summary files are found in the folder.

    Returns
    -------
    summary_table : pd.DataFrame
        Table with one column per summary file; the first column of each file is used as the row index.
        Column names are produced by `utils.clean_flanking_strings` from the file paths.
    """

    # glob is called without recursive=True, so "**" behaves like "*" and
    # matches exactly one directory level (the run folder).
    paths = glob.glob(folder + "/**/solo/Gene/Summary.csv")
    if not paths:
        raise ValueError(f"No STARsolo summary files found in folder '{folder}'. Please check the path and try again.")

    # One name per file, in the same order as the paths
    run_names = utils.clean_flanking_strings(paths)

    # Read each file into a single-column table and join them side by side
    per_run = [pd.read_csv(path, index_col=0, header=None, names=[run_name])
               for run_name, path in zip(run_names, paths)]

    return pd.concat(per_run, axis=1)
[docs]
@beartype
def plot_starsolo_quality(folder: str,
                          measures: list[str] = ["Number of Reads", "Reads Mapped to Genome: Unique",
                                                 "Reads Mapped to Gene: Unique Gene", "Fraction of Unique Reads in Cells",
                                                 "Median Reads per Cell", "Median Gene per Cell"],
                          ncol: int = 3,
                          order: Optional[list[str]] = None,
                          save: Optional[str] = None,
                          **kwargs: Any) -> np.ndarray:
    """Plot quality measures from starsolo as barplots per condition.

    Parameters
    ----------
    folder : str
        Path to a folder, e.g. "path/to/starsolo_output", which contains folders "solorun1", "solorun2", etc.
    measures : list[str], default ["Number of Reads", "Reads Mapped to Genome: Unique", "Reads Mapped to Gene: Unique Gene", "Fraction of Unique Reads in Cells", "Median Reads per Cell", "Median Gene per Cell"]
        List of measures to plot. Must be available in the solo summary table. The list is only read, never modified.
    ncol : int, default 3
        Number of columns in the plot. Must be at least 1.
    order : Optional[list[str]], default None
        Order of conditions in the plot. If None, the order is alphabetical.
    save : Optional[str], default None
        Path to save the plot. If None, the plot is not saved.
    **kwargs : Any
        Additional arguments passed to seaborn.barplot.

    Returns
    -------
    axes : np.ndarray
        Array of axes objects containing the plot(s).

    Raises
    ------
    KeyError
        If a measure or an entry of `order` is not available in the solo summary table.
    ValueError
        If `measures` is empty or `ncol` is smaller than 1.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.plot_starsolo_quality("data/quant/")
    """

    # Prepare functions for converting tick values to labels
    def format_million(label):
        return '{:,.0f} M'.format(int(label) / 10**6)

    def format_thousand(label):
        return '{:,.0f} K'.format(int(label) / 10**3)

    def format_percent(label):
        return '{:,.0f}%'.format(float(label) * 100)

    # Get summary table
    summary_table = _read_starsolo_summary(folder)
    available_measures = summary_table.index.tolist()
    available_conditions = summary_table.columns.tolist()

    # Validate all input before a figure is created, so that a failing call
    # does not leave a half-finished figure open.
    if order is None:
        order = sorted(available_conditions)
    else:
        unknown = [condition for condition in order if condition not in available_conditions]
        if unknown:
            raise KeyError(f"Conditions {unknown} given in 'order' not found in summary table. Available conditions: {available_conditions}")

    if len(measures) == 0:
        raise ValueError("'measures' is empty. Please provide at least one measure to plot.")
    if ncol < 1:
        raise ValueError(f"'ncol' must be at least 1, got {ncol}.")

    for measure in measures:
        if measure not in available_measures:
            raise KeyError(f"Measure '{measure}' not found in summary table. Available measures: {available_measures}")

    summary_table = summary_table[order]

    # Setup plot
    ncol = min(ncol, len(measures))
    row = int(np.ceil(len(measures) / ncol))
    # squeeze=False always returns a 2D array, which flattens to a 1D array of axes
    fig, axes = plt.subplots(row, ncol, figsize=(ncol * 4, row * 4), squeeze=False)
    axes = axes.flatten()

    # Hide axes of the grid which are not needed
    for ax in axes[len(measures):]:
        ax.axis('off')

    # Fill in plot per measure
    for ax, measure in zip(axes, measures):

        # Plot data to barplot
        data = summary_table.loc[measure].astype(float)
        sns.barplot(x=data.index, y=data.values, ax=ax, edgecolor="black", **kwargs)
        ax.set_title(measure)

        # Choose the format of yticklabels from the magnitude of the data
        max_value = data.max()
        if max_value < 1:  # convert to %
            ax.set_ylim(0, 1)
            formatter = format_percent
        elif max_value < 10000:
            formatter = None  # no format; show raw values
        elif max_value < 10**6:  # convert to thousands
            formatter = format_thousand
        else:  # convert to millions
            formatter = format_million

        if formatter is not None:
            yticks = ax.get_yticks()  # read after set_ylim, which may change the ticks
            ax.set_yticks(yticks, [formatter(value) for value in yticks])

        ax.set_xticks(ax.get_xticks())  # prevent locator error
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")

    fig.tight_layout()
    _save_figure(save)

    return axes
[docs]
@beartype
def plot_starsolo_UMI(folder: str,
                      ncol: int = 3,
                      save: Optional[str] = None) -> np.ndarray:
    """Plot UMI distribution for each condition in a folder.

    Parameters
    ----------
    folder : str
        Path to a folder, e.g. "path/to/starsolo_output", which contains folders "solorun1", "solorun2", etc.
    ncol : int, default 3
        Number of columns in the plot. Must be at least 1.
    save : Optional[str], default None
        Path to save the plot. If None, the plot is not saved.

    Returns
    -------
    axes : np.ndarray
        Array of axes objects containing the plot(s).

    Raises
    ------
    ValueError
        If no UMI files ('UMIperCellSorted.txt') are found in the folder or if `ncol` is smaller than 1.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.plot_starsolo_UMI("data/quant/", ncol=2)
    """

    if ncol < 1:
        raise ValueError(f"'ncol' must be at least 1, got {ncol}.")

    summary_table = _read_starsolo_summary(folder)

    # glob returns files in file-system order; sort them so that the panel
    # order is reproducible between machines.
    umi_files = sorted(glob.glob(folder + "/**/solo/Gene/UMIperCellSorted.txt"))
    if len(umi_files) == 0:
        raise ValueError("No UMI files found in folder. Please check the path and try again.")

    names = utils.clean_flanking_strings(umi_files)

    # Setup plot
    ncol = min(len(names), ncol)
    nrow = int(np.ceil(len(names) / ncol))
    # squeeze=False always returns a 2D array, which flattens to a 1D array of axes
    fig, axes = plt.subplots(nrow, ncol, figsize=(ncol * 4, nrow * 4), squeeze=False)
    axes = axes.flatten()

    # Hide axes of the grid which are not needed
    for ax in axes[len(names):]:
        ax.axis('off')

    for ax, name, f in zip(axes, names, umi_files):

        df_knee = pd.read_table(f, header=None, names=[name])

        # Number of barcodes counted as cells according to the summary table
        cut = int(summary_table.loc["Estimated Number of Cells", name])

        # Full curve on log-log axes; the first `cut` barcodes are drawn again in red on top
        df_knee.plot.line(logx=True, logy=True, legend=False, ax=ax)
        df_knee[:cut].plot.line(logx=True, legend=False, ax=ax, color='red')
        ax.axvline(x=cut, color='grey', linestyle='-')

        # Label the cutoff next to the vertical line, at the height of the first value in the file
        vmax = df_knee.iloc[0, 0]
        ax.text(cut * 1.2, vmax, str(cut) + ' cells', verticalalignment='center')

        ax.set_title(name)
        ax.set_xlabel('Barcodes')
        ax.set_ylabel('UMI count')

    fig.tight_layout()
    _save_figure(save)

    return axes
########################################################################################
# ---------------------------- Plots for counting cells ------------------------------ #
########################################################################################
@deco.log_anndata
@beartype
def _n_cells_pieplot(adata: sc.AnnData,
                     groupby: str,
                     figsize: Optional[Tuple[int | float, int | float]] = None):
    """
    Plot number of cells per group in a pieplot.

    NOTE: this function is unfinished. It currently only counts the cells per group; nothing is drawn and nothing is returned.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    groupby : str
        Name of the column in adata.obs to group by.
    figsize : tuple, default None
        Size of figure, e.g. (4, 8). Not used yet.
    """

    # Number of cells for each value of the grouping column
    cells_per_group = adata.obs[groupby].value_counts()  # noqa: F841

    # TODO: draw the pie chart from the counts (in progress)
[docs]
@deco.log_anndata
@beartype
def n_cells_barplot(adata: sc.AnnData,
                    x: str,
                    groupby: Optional[str] = None,
                    stacked: bool = True,
                    save: Optional[str] = None,
                    figsize: Optional[Tuple[int | float, int | float]] = None,
                    add_labels: bool = False,
                    **kwargs: Any) -> Iterable[matplotlib.axes.Axes]:
    """
    Plot number and percentage of cells per group in a barplot.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    x : str
        Name of the column in adata.obs to group by on the x axis.
    groupby : Optional[str], default None
        Name of the column in adata.obs to create stacked bars on the y axis. If None, the bars are not split.
    stacked : bool, default True
        Whether to stack the bars or not.
    save : Optional[str], default None
        Path to save the plot. If None, the plot is not saved.
    figsize : Optional[Tuple[int | float, int | float]], default None
        Size of figure, e.g. (4, 8). If None, size is determined automatically depending on whether groupby is None or not.
    add_labels : bool, default False
        Whether to add labels to the bars giving the number/percentage of cells. Segments without cells get no label.
    **kwargs : Any
        Additional arguments passed to pandas.DataFrame.plot.bar.

    Returns
    -------
    axarr : Iterable[matplotlib.axes.Axes]
        Array of axes objects containing the plot(s): the count plot, followed by the percentage plot if `groupby` is given.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.n_cells_barplot(adata, x="louvain")

    .. plot::
        :context: close-figs

        pl.n_cells_barplot(adata, x="louvain", groupby="condition")
    """

    # Get cell counts for groups or all
    if groupby is not None:
        tables = []
        for group, frame in adata.obs.groupby(groupby):
            count = frame.value_counts(x).to_frame(name="count").reset_index()
            count["groupby"] = group
            tables.append(count)
        counts = pd.concat(tables)
    else:
        counts = adata.obs[x].value_counts().to_frame(name="count").reset_index()
        # Depending on the pandas version, the reset column may be called "index" instead of x
        counts.rename(columns={"index": x}, inplace=True)
        counts["groupby"] = "all"

    # Format counts: one row per x value, one column per group
    counts_wide = counts.pivot(index=x, columns="groupby", values="count")
    counts_wide_percent = counts_wide.div(counts_wide.sum(axis=1), axis=0) * 100

    # Plot barplots
    if figsize is None:
        figsize = (5 + 5 * (groupby is not None), 3)  # if groupby is not None, add 5 to width

    if groupby is not None:
        _, axarr = plt.subplots(1, 2, figsize=figsize)
    else:
        _, axarr = plt.subplots(1, 1, figsize=figsize)  # axarr is a single axes
        axarr = [axarr]

    counts_wide.plot.bar(stacked=stacked, ax=axarr[0], legend=False, **kwargs)
    axarr[0].set_title("Number of cells")
    axarr[0].set_xticklabels(axarr[0].get_xticklabels(), rotation=45, ha="right")
    axarr[0].grid(False)

    if groupby is not None:
        counts_wide_percent.plot.bar(stacked=stacked, ax=axarr[1], **kwargs)
        axarr[1].set_title("Percentage of cells")
        axarr[1].set_xticklabels(axarr[1].get_xticklabels(), rotation=45, ha="right")
        axarr[1].grid(False)

        # Set location of legend
        axarr[1].legend(title=groupby, bbox_to_anchor=(1, 1), frameon=False,
                        handlelength=1, handleheight=1  # make legend markers square
                        )

        # Draw line at 100% if values are stacked
        if stacked is True:
            axarr[1].axhline(100, color='black', linestyle='--', linewidth=0.5, zorder=0)

    # Add labels to bars
    if add_labels:
        for i, ax in enumerate(axarr):
            for container in ax.containers:
                heights = [bar.get_height() for bar in container]

                # Segments with height 0 or NaN (the comparison is False for NaN) get an empty label.
                # The number is only formatted for non-empty segments; formatting the empty
                # placeholder would raise an error.
                if i == 0:
                    labels = [str(int(height)) if height > 0 else "" for height in heights]  # counts as int
                else:
                    # round to one decimal, add % sign and remove .0 (e.g. 100.0% -> 100%)
                    labels = [f"{height:.1f}%".replace(".0", "") if height > 0 else "" for height in heights]

                ax.bar_label(container, labels=labels, label_type='center')

    # Remove spines
    for ax in axarr:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    _save_figure(save)

    return axarr
[docs]
@deco.log_anndata
@beartype
def group_correlation(adata: sc.AnnData,
                      groupby: str,
                      method: Literal["spearman", "pearson", "kendall"] | Callable = "spearman",
                      save: Optional[str] = None,
                      **kwargs: Any) -> sns.matrix.ClusterGrid:
    """Plot correlation matrix between groups in `groupby`.

    The function expects the count data in .X to be normalized across cells.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    groupby : str
        Name of the column in adata.obs to group cells by.
    method : Literal["spearman", "pearson", "kendall"] | Callable, default "spearman"
        Correlation method to use. See pandas.DataFrame.corr for options.
        The colorbar is labelled with the method name; for a callable its `__name__` is used
        ("correlation" if the callable has no name).
    save : Optional[str], default None
        Path to save the plot. If None, the plot is not saved.
    **kwargs : Any
        Additional arguments passed to seaborn.clustermap. These overwrite the defaults set by this function.

    Returns
    -------
    sns.matrix.ClusterGrid
        The clustermap object.

    Examples
    --------
    .. plot::
        :context: close-figs

        import scanpy as sc
        import sctoolbox.plotting as pl

    .. plot::
        :context: close-figs

        adata = sc.datasets.pbmc68k_reduced()

    .. plot::
        :context: close-figs

        pl.group_correlation(adata, "phase", method="spearman", save=None)
    """

    # Calculate correlation of groups
    count_table = utils.pseudobulk_table(adata, groupby=groupby)
    corr = count_table.corr(numeric_only=False, method=method)

    # A callable would otherwise show up in the colorbar label as its function repr
    if isinstance(method, str):
        method_label = method
    else:
        method_label = getattr(method, "__name__", "correlation")

    clustermap_kwargs = {"figsize": (4, 4),
                         "cmap": "Reds",
                         "xticklabels": True,
                         "yticklabels": True,
                         "cbar_kws": {'orientation': 'horizontal', 'label': method_label}}  # defaults
    clustermap_kwargs.update(kwargs)  # overwrite defaults with user input

    # Plot clustermap
    g = sns.clustermap(corr, **clustermap_kwargs)
    g.ax_heatmap.set_facecolor("grey")

    # Adjust cbar: place it below the heatmap with half the height of one heatmap row
    n = len(corr)
    pos = g.ax_heatmap.get_position()
    cbar_h = pos.height / n / 2
    g.ax_cbar.set_position([pos.x0, pos.y0 - 3 * cbar_h, pos.width, cbar_h])

    # Final adjustments
    g.ax_col_dendrogram.set_visible(False)
    g.ax_heatmap.xaxis.tick_top()
    _ = g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=45, ha="left")

    _save_figure(save)

    return g
#####################################################################
# --------------------------- Insertsize -------------------------- #
#####################################################################
[docs]
@deco.log_anndata
@beartype
def plot_insertsize(adata: sc.AnnData,
                    barcodes: Optional[list[str]] = None,
                    **kwargs: Any) -> matplotlib.axes.Axes:
    """
    Draw the summed insertsize distribution of barcodes as a line plot.

    The distribution is read from adata.uns["insertsize_distribution"], which must be set beforehand.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing insertsize distribution in adata.uns["insertsize_distribution"].
    barcodes : Optional[list[str]], default None
        Subset of barcodes to plot information for. If None, all barcodes are used.
    **kwargs : Any
        Additional arguments passed to seaborn.lineplot.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes object containing the plot.

    Raises
    ------
    ValueError
        If adata.uns["insertsize_distribution"] is not set.
    """

    if "insertsize_distribution" not in adata.uns:
        raise ValueError("adata.uns['insertsize_distribution'] not found!")

    # Work on a copy, as the column labels are converted to int below
    distribution = copy.deepcopy(adata.uns['insertsize_distribution'])
    distribution.columns = distribution.columns.astype(int)

    # Select the rows to sum over
    if barcodes is None:
        selected = distribution
    else:
        # A single barcode given as string is treated as a list with one entry
        wanted = [barcodes] if isinstance(barcodes, str) else barcodes
        selected = distribution.loc[wanted]

    # Sum per column over the selected rows
    totals = selected.sum(axis=0)

    # Plot
    ax = sns.lineplot(x=totals.index, y=totals.values, **kwargs)
    ax.set_xlabel("Insertsize (bp)")
    ax.set_ylabel("Count")

    return ax
###########################################################################
# ----------------- Interactive quality control plot -------------------- #
###########################################################################
@beartype
def _link_sliders(sliders: list[ipywidgets.widgets.FloatRangeSlider]) -> list[ipywidgets.link]:
    """Chain sliders together by linking the 'value' of each slider to the next one.

    Parameters
    ----------
    sliders : list[ipywidgets.widgets.FloatRangeSlider]
        List of sliders to link.

    Returns
    -------
    list[ipywidgets.link]
        One link per pair of neighbouring sliders; the list is empty for fewer than two sliders.
    """

    endpoints = [(slider, 'value') for slider in sliders]

    # Pair every endpoint with its successor: (0, 1), (1, 2), ...
    return [ipywidgets.link(left, right) for left, right in zip(endpoints, endpoints[1:])]
@beartype
def _toggle_linkage(checkbox: ipywidgets.widgets.Checkbox | traitlets.utils.bunch.Bunch,  # after first check, checkbox is a bunch object
                    linkage_dict: dict,
                    slider_list: list,
                    key: str):
    """
    Link or unlink the sliders of one measure according to the new checkbox value.

    Parameters
    ----------
    checkbox : ipywidgets.widgets.Checkbox
        Checkbox to toggle linkage. Only its "new" entry is read.
    linkage_dict : dict
        Dictionary of links to link or unlink. Updated in place when links are created.
    slider_list : list of ipywidgets.widgets.Slider
        List of sliders to link or unlink.
    key : str
        Key in linkage_dict for fetching and updating links.
    """

    is_checked = checkbox["new"]

    if is_checked is True:
        # Create the links on first use and store them, replacing the None placeholder
        if linkage_dict[key] is None:
            linkage_dict[key] = _link_sliders(slider_list)

        for link in linkage_dict[key]:
            link.link()
        return

    if is_checked is False:
        existing = linkage_dict[key]
        if existing is None:  # nothing has been linked, so there is nothing to unlink
            return

        for link in existing:
            link.unlink()
def _update_thresholds(slider, fig, min_line, min_shade, max_line, max_shade):
    """Move the threshold lines and shaded areas to the new slider values and redraw.

    `slider["new"]` holds the (min, max) pair of thresholds. Each line is set to its threshold
    at every point, and the height of each shade is set so that it reaches from its anchor
    (the y of `get_xy()`) to the threshold.
    """

    lower, upper = slider["new"]  # threshold values from slider

    # Same update for the min pair and the max pair, min first
    for threshold, line, shade in ((lower, min_line, min_shade),
                                   (upper, max_line, max_shade)):
        n_points = len(line.get_ydata())
        line.set_ydata([threshold] * n_points)

        _, anchor_y = shade.get_xy()
        shade.set_height(threshold - anchor_y)

    # Draw figure after update
    fig.canvas.draw_idle()
[docs]
@deco.log_anndata
@beartype
def quality_violin(adata: sc.AnnData,
                   columns: list[str],
                   which: Literal["obs", "var"] = "obs",
                   groupby: Optional[str] = None,
                   ncols: int = 2,
                   header: Optional[list[str]] = None,
                   color_list: Optional[list[str | Tuple[float | int, float | int, float | int]]] = None,
                   title: Optional[str] = None,
                   thresholds: Optional[dict[str, dict[str, dict[Literal["min", "max"], int | float]] | dict[Literal["min", "max"], int | float]]] = None,
                   global_threshold: bool = True,
                   interactive: bool = True,
                   save: Optional[str] = None,
                   **kwargs: Any
                   ) -> Tuple[Any, Dict[str, Any]]:
    """
    Plot quality measurements for cells/features in an anndata object.

    Notes
    -----
    Notebook needs "%matplotlib widget" before the call for the interactive sliders to work.

    Parameters
    ----------
    adata : sc.AnnData
        Anndata object containing quality measures in .obs/.var
    columns : list[str]
        A list of columns in .obs/.var to show measures for. Must not be empty.
    which : Literal["obs", "var"], default "obs"
        Which table to show quality for. Either "obs" / "var".
    groupby : Optional[str], default None
        A column in table to group values by on the x-axis. If None, one violin is shown per measure.
    ncols : int, default 2
        Number of columns in the plot. Must be at least 1.
    header : Optional[list[str]], default None
        A list of custom headers for each measure given in columns.
    color_list : Optional[list[str]], default None
        A list of colors to use for violins. If None, colors are chosen automatically.
    title : Optional[str], default None
        The title of the full plot.
    thresholds : Optional[dict[str, dict[str, dict[Literal["min", "max"], int | float]] | dict[Literal["min", "max"], int | float]]], default None
        Dictionary containing initial min/max thresholds to show in plot, either per column
        (thresholds[column] = {"min": ..., "max": ...}) or per column and group
        (thresholds[column][group] = {"min": ..., "max": ...}).
    global_threshold : bool, default True
        Whether to use global thresholding as the initial setting. If False, thresholds are set per group.
    interactive : bool, default True
        Whether to show interactive sliders. If False, the static matplotlib plot is shown.
        Sliders are only shown if `utils._is_interactive()` additionally reports an interactive environment.
    save : Optional[str], optional
        Save the figure to the path given in 'save'. Default: None (figure is not saved).
    **kwargs : Any
        Additional arguments passed to seaborn.violinplot.

    Returns
    -------
    Tuple[Any, Dict[str, Any]]
        The first element is the matplotlib figure (static) or an ipywidgets.HBox holding sliders and figure (interactive).
        The second element is a dict of sliders in the format slider_dict[column][group] = slider,
        or slider_dict[column] = slider if `groupby` is None. The slider values are continuously updated.
        For a static plot, slider_dict[column] is an empty dict.

    Raises
    ------
    ValueError
        If 'which' is not 'obs' or 'var', if columns are not in table, if `columns` is empty, if `ncols` is smaller than 1,
        if `color_list` holds fewer colors than there are groups or if `header` and `columns` differ in length.
    """

    # ---------------- Test input and get ready --------------#

    if len(columns) == 0:
        raise ValueError("'columns' is empty. Please provide at least one column to plot.")
    if ncols < 1:
        raise ValueError(f"'ncols' must be at least 1, got {ncols}.")

    # Sliders are only built if the caller asked for them AND the environment is
    # reported as interactive; 'interactive=False' always gives the static plot.
    is_interactive = interactive and utils._is_interactive()

    ncols = min(ncols, len(columns))  # Make sure ncols is not larger than the number of columns
    nrows = int(np.ceil(len(columns) / ncols))

    # Decide which table to use
    if which == "obs":
        table = adata.obs
    elif which == "var":
        table = adata.var
    else:  # only reachable if the beartype check of the signature is bypassed
        raise ValueError(f"'which' must be either 'obs' or 'var', got '{which}'.")

    # Check that columns are in table
    invalid_columns = set(columns) - set(table.columns)
    if invalid_columns:
        raise ValueError(f"The following columns from 'columns' were not found in '{which}' table: {invalid_columns}")

    # Order of categories on x axis
    if groupby is not None:
        groups = list(table[groupby].astype('category').cat.categories)
        n_colors = len(groups)
        group_names = groups
    else:
        groups = None
        n_colors = 1
        group_names = ["Threshold"]  # label of the single slider per measure

    # Setup colors to be used
    if color_list is None:
        color_list = sns.color_palette("Set1", n_colors)
    elif n_colors > len(color_list):
        raise ValueError("Increase the color_list variable to at least {} colors.".format(n_colors))
    else:
        color_list = color_list[:n_colors]

    # Setup headers to be used
    if header is None:
        header = columns
    elif len(header) != len(columns):
        raise ValueError("Length of header does not match length of columns")

    # Setup thresholds if not given
    if thresholds is None:
        thresholds = {col: {} for col in columns}

    # ---------------- Setup figure --------------#

    # Setting up output figure
    plt.ioff()  # prevent plot from showing twice in notebook

    if is_interactive:
        figsize = (ncols * 3, nrows * 3)
    else:
        figsize = (ncols * 4, nrows * 4)  # static plot can be larger

    # squeeze=False always returns a 2D array, also for a single axes
    fig, axarr = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes_list = axarr.flatten()

    # Remove empty axes
    for ax in axes_list[len(columns):]:
        ax.axis('off')

    # Add title of full plot
    if title is not None:
        suptitle = fig.suptitle(title)
        suptitle.set_fontsize(suptitle.get_fontsize() * 1.2)  # increase fontsize of title

    # Add title of individual plots
    for ax, head in zip(axes_list, header):
        ax.set_title(head, fontsize=11)

    # ------------- Plot data and add sliders ---------#

    slider_dict = {}
    linkage_dict = {}  # one link list per column
    accordion_content = []
    for ax, column in zip(axes_list, columns):
        slider_dict[column] = {}

        # Plot data from table
        sns.violinplot(data=table, x=groupby, y=column, ax=ax, order=groups, palette=color_list, cut=0, **kwargs)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
        ax.set_ylabel("")
        ax.set_xlabel("")

        ticks = ax.get_xticks()
        ymin, ymax = ax.get_ylim()  # ylim before plotting any thresholds

        # Values used when no threshold is given: just outside of y axis range
        y_range = ymax - ymin
        nothresh_min = ymin - y_range * 0.1
        nothresh_max = ymax + y_range * 0.1

        data_min = table[column].min()
        data_max = table[column].max()

        # Plot thresholds per group
        slider_list = []
        for tick, group in zip(ticks, group_names):

            # Establish the threshold to plot
            if column not in thresholds:  # no thresholds given
                tmin = nothresh_min
                tmax = nothresh_max
            elif group in thresholds[column]:  # thresholds per group
                tmin = thresholds[column][group].get("min", nothresh_min)
                tmax = thresholds[column][group].get("max", nothresh_max)
            else:  # one threshold for all groups
                tmin = thresholds[column].get("min", nothresh_min)
                tmax = thresholds[column].get("max", nothresh_max)

            # Replace None with nothresh
            tmin = nothresh_min if tmin is None else tmin
            tmax = nothresh_max if tmax is None else tmax

            # Plot line and shading
            x = [tick - 0.5, tick + 0.5]
            min_line = ax.plot(x, [tmin] * 2, color="red", linestyle="--")[0]
            max_line = ax.plot(x, [tmax] * 2, color="red", linestyle="--")[0]

            min_shade = ax.add_patch(Rectangle((x[0], ymin), x[1] - x[0], tmin - ymin, color="grey", alpha=0.2, linewidth=0))  # starting at lower left with positive height
            max_shade = ax.add_patch(Rectangle((x[0], ymax), x[1] - x[0], tmax - ymax, color="grey", alpha=0.2, linewidth=0))  # starting at upper left with negative height

            # Add slider to control thresholds
            if is_interactive:

                # The description is converted to str so that non-string groups can be shown as well
                slider = ipywidgets.FloatRangeSlider(description=str(group), min=data_min, max=data_max,
                                                     value=[tmin, tmax],  # initial value
                                                     continuous_update=False)

                slider.observe(functools.partial(_update_thresholds,
                                                 fig=fig,
                                                 min_line=min_line,
                                                 min_shade=min_shade,
                                                 max_line=max_line,
                                                 max_shade=max_shade), names=["value"])

                slider_list.append(slider)
                if groupby is not None:
                    slider_dict[column][group] = slider
                else:
                    slider_dict[column] = slider

        ax.set_ylim(ymin, ymax)  # set ylim back to original after plotting thresholds

        # Link sliders together
        if is_interactive:

            if len(slider_list) > 1:

                # Toggle linked sliders
                c = ipywidgets.Checkbox(value=global_threshold, description='Global threshold', disabled=False, indent=False)
                linkage_dict[column] = _link_sliders(slider_list) if global_threshold is True else None

                c.observe(functools.partial(_toggle_linkage,
                                            linkage_dict=linkage_dict,
                                            slider_list=slider_list,
                                            key=column), names=["value"])

                box = ipywidgets.VBox([c] + slider_list)

            else:
                box = ipywidgets.VBox(slider_list)  # no tickbox needed if there is only one slider per column

            accordion_content.append(box)

    fig.tight_layout()
    _save_figure(save)  # save plot; can be overwritten if thresholds are changed

    # Assemble accordion with different measures
    if is_interactive:

        accordion = ipywidgets.Accordion(children=accordion_content, selected_index=None)
        for i, column in enumerate(columns):
            accordion.set_title(i, column)

        fig.canvas.header_visible = False
        fig.canvas.toolbar_visible = False
        fig.canvas.resizable = True
        fig.canvas.width = "auto"

        # Hack to force the plot to show
        # reference: https://github.com/matplotlib/ipympl/issues/290
        fig.canvas._handle_message(fig.canvas, {'type': 'send_image_mode'}, [])
        fig.canvas._handle_message(fig.canvas, {'type': 'refresh'}, [])
        fig.canvas._handle_message(fig.canvas, {'type': 'initialized'}, [])
        fig.canvas._handle_message(fig.canvas, {'type': 'draw'}, [])
        fig.canvas.draw()

        figure = ipywidgets.HBox([accordion, fig.canvas])  # Setup box to hold all widgets

    else:
        figure = fig  # non interactive figure

    return (figure, slider_dict)
[docs]
@beartype
def get_slider_thresholds(slider_dict: dict) -> dict:
    """Read the current min/max thresholds from a dictionary of sliders.

    Parameters
    ----------
    slider_dict : dict
        Dictionary of sliders in the format 'slider_dict[column][group] = slider' or 'slider_dict[column] = slider' if no grouping.

    Returns
    -------
    dict
        dict in the format threshold_dict[column][group] = {"min": min_threshold, "max": max_threshold} or
        threshold_dict[column] = {"min": min_threshold, "max": max_threshold} if no grouping.
        If all groups of a column share the same min and max, the ungrouped format is used for that column.
    """

    threshold_dict = {}
    for measure, entry in slider_dict.items():

        # One threshold for measure: entry is the slider itself
        if not isinstance(entry, dict):
            threshold_dict[measure] = {"min": entry.value[0], "max": entry.value[1]}
            continue

        # Thresholds for groups: entry maps group -> slider
        per_group = {group: {"min": slider.value[0], "max": slider.value[1]}
                     for group, slider in entry.items()}

        distinct_mins = {limits["min"] for limits in per_group.values()}
        distinct_maxs = {limits["max"] for limits in per_group.values()}

        if len(distinct_mins) == 1 and len(distinct_maxs) == 1:
            # All groups agree, so a single overall threshold is reported (the entry of the last group)
            threshold_dict[measure] = list(per_group.values())[-1]
        else:
            # Groups differ, or there are no groups at all (then this is an empty dict)
            threshold_dict[measure] = per_group

    return threshold_dict