"""Functions for plotting different single-cell embeddings, e.g. UMAP, PCA, tSNE."""
import multiprocessing as mp
import warnings
import scanpy as sc
import numpy as np
import pandas as pd
import scipy.stats
from scipy.sparse import issparse
import itertools
import seaborn as sns
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import plotly as po
import plotly.graph_objects as go
from numba import errors as numba_errors
from beartype import beartype
from beartype.typing import Literal, Tuple, Optional, Union, Any, List, Annotated
from beartype.vale import Is
import numpy.typing as npt
import sctoolbox.utils as utils
import sctoolbox.tools as tools
from sctoolbox.plotting.general import _save_figure, _make_square, boxplot
import sctoolbox.utils.decorator as deco
from sctoolbox._settings import settings
# Module logger, taken from the sctoolbox settings object imported above
logger = settings.logger
#############################################################################
# Utilities #
#############################################################################
[docs]
@beartype
def sc_colormap() -> matplotlib.colors.ListedColormap:
    """Get a colormap with 0-count cells colored grey (to use for embeddings).

    The 'Reds' colormap is sampled at 200 points between 0.2 and 0.9, and the
    first (lowest) entry is replaced by light grey so that cells with a value
    of 0 stand out from expressing cells.

    Returns
    -------
    cmap : matplotlib.colors.ListedColormap
        Colormap with 0-count cells colored grey.
    """
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9; the (name, lut) signature is the same.
    color_cmap = plt.get_cmap('Reds', 200)
    newcolors = color_cmap(np.linspace(0.2, 0.9, 200))
    newcolors[0, :] = colors.to_rgba("lightgrey")  # count 0 = grey
    return ListedColormap(newcolors)
[docs]
def grey_colormap() -> matplotlib.colors.ListedColormap:
    """Get a colormap with grey-scale colors, but without white to still show cells.

    The 'Greys' colormap is sampled at 200 points between 0.2 and 1, so the
    lightest (near-white) part of the original colormap is skipped.

    Returns
    -------
    cmap : matplotlib.colors.ListedColormap
        Grey-scale colormap.
    """
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9; the (name, lut) signature is the same.
    color_cmap = plt.get_cmap('Greys', 200)
    newcolors = color_cmap(np.linspace(0.2, 1, 200))
    return ListedColormap(newcolors)
[docs]
@deco.log_anndata
@beartype
def flip_embedding(adata: sc.AnnData, key: str = "X_umap", how: Literal["vertical", "horizontal"] = "vertical"):
    """Mirror the embedding stored in adata.obsm[key] in place.

    A vertical flip negates the second column (y), a horizontal flip negates
    the first column (x).

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    key : str, default "X_umap"
        Key in adata.obsm to flip.
    how : Literal["vertical", "horizontal"], default "vertical"
        Axis to flip along. Can be "vertical" (flips up/down) or "horizontal" (flips left/right).

    Raises
    ------
    KeyError
        If the given key is not found in adata.obsm.
    ValueError
        If the given 'how' is not supported.
    """
    if key not in adata.obsm:
        raise KeyError(f"The given key '{key}' cannot be found in adata.obsm. Please check the key value")

    # Which coordinate column is mirrored for each flip direction
    column_for_direction = {"vertical": 1, "horizontal": 0}
    if how not in column_for_direction:
        raise ValueError("The given axis '{0}' is not supported. Please use 'vertical' or 'horizontal'.".format(how))

    col = column_for_direction[how]
    adata.obsm[key][:, col] = -adata.obsm[key][:, col]
#####################################################################
# -------------------- UMAP / tSNE embeddings ----------------------#
#####################################################################
@beartype
def _add_contour(x: np.ndarray,
                 y: np.ndarray,
                 ax: matplotlib.axes.Axes):
    """Overlay black density contour lines for the points (x, y) on an existing axis.

    A Gaussian kernel density estimate of the points is evaluated on a
    100 x 100 grid spanning the axis' current x/y limits.

    Parameters
    ----------
    x : np.ndarray
        x-coordinates of the scatter plot.
    y : np.ndarray
        y-coordinates of the scatter plot.
    ax : matplotlib.axes.Axes
        Axis object to add the contour plot to.
    """
    # Fit the kernel density estimate on the observed points
    kde = scipy.stats.gaussian_kde(np.vstack([x, y]))

    # Evaluation grid covering the currently visible area of the axis
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    grid_x, grid_y = np.mgrid[x_lo:x_hi:100j, y_lo:y_hi:100j]
    grid_points = np.vstack([grid_x.ravel(), grid_y.ravel()])
    density = kde(grid_points).reshape(grid_x.shape)

    ax.contour(grid_x, grid_y, density, colors="black", linewidths=0.5)
[docs]
@deco.log_anndata
@beartype
def plot_embedding(adata: sc.AnnData,
                   method: str = "umap",
                   color: Optional[list[str | None] | str] = None,
                   style: Literal["dots", "hexbin", "density"] = "dots",
                   show_borders: bool = False,
                   show_contour: bool = False,
                   show_count: bool = True,
                   show_title: bool = True,
                   hexbin_gridsize: int = 30,
                   shrink_colorbar: float | int = 0.3,
                   square: bool = True,
                   save: Optional[str] = None,
                   **kwargs) -> npt.ArrayLike:
    """Plot a dimensionality reduction embedding e.g. UMAP or tSNE with different style options.

    This is a wrapper around :func:`scanpy.pl.embedding`.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix object.
    method : str, default "umap"
        Dimensionality reduction method to use. Must be a key in adata.obsm, or a method available as "X_<method>" such as "umap", "tsne" or "pca".
    color : Optional[str | list[str]], default None
        Key for annotation of observations/cells or variables/genes.
    style : Literal["dots", "hexbin", "density"], default "dots"
        Style of the plot. Must be one of "dots", "hexbin" or "density".
    show_borders : bool, default False
        Whether to show borders around embedding plot. If False, the borders are removed and small axis arrows are drawn in the lower left corner instead.
    show_contour : bool, default False
        Whether to show a contour plot on top of the plot.
    show_count : bool, default True
        Whether to show the number of cells in the plot.
    show_title : bool, default True
        Whether to show the titles of the plots. If False, the titles are removed and the names are added to the colorbar/legend instead.
    hexbin_gridsize : int, default 30
        Number of hexbins across plot - higher values give smaller bins. Only used if style="hexbin".
    shrink_colorbar : float | int, default 0.3
        Shrink the height of the colorbar by this factor.
    square : bool, default True
        Whether to make the plot square.
    save : Optional[str], default None
        Filename to save the figure.
    **kwargs : arguments
        Additional keyword arguments are passed to :func:`scanpy.pl.embedding`.
        The colormap can be given as either 'color_map' or 'cmap'; if neither is given, :func:`sc_colormap` is used.
        'components' can be a string such as "1,2", a list of such strings, or "all".

    Returns
    -------
    axes : npt.ArrayLike
        List of axis objects, one per combination of color and components.

    Raises
    ------
    KeyError
        If the given method is not found in adata.obsm.
    ValueError
        If the 'components' given is larger than the number of components in the embedding,
        or if style="hexbin" is used with a non-continuous color variable.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.plot_embedding(adata, color="louvain", legend_loc="on data")

    .. plot::
        :context: close-figs

        _ = pl.plot_embedding(adata, method="pca", color="n_genes", show_contour=True, show_title=False)

    .. plot::
        :context: close-figs

        _ = pl.plot_embedding(adata, color=['n_genes', 'HES4'], style="hexbin")

    .. plot::
        :context: close-figs

        _ = pl.plot_embedding(adata, method="pca", color=['n_genes', 'HES4'],
                              style="hexbin", components=["1,2", "2,3"], ncols=2)

    .. plot::
        :context: close-figs

        ax = pl.plot_embedding(adata, color=['n_genes', 'louvain'], style="density")
    """

    # Get key in obsm from method
    if method in adata.obsm:  # method is directly available in obsm
        obsm_key = method
    elif "X_" + method in adata.obsm:  # method is available as "X_<method>"
        obsm_key = "X_" + method
    else:
        raise KeyError(f"The given method '{method}' or 'X_{method}' cannot be found in adata.obsm. The available keys are: {list(adata.obsm.keys())}.")

    # ---- Plot embedding for chosen colors ---- #
    # Get embedding dimensions if passed as a kwarg, otherwise use default dimensions 1 and 2
    n_components = adata.obsm[obsm_key].shape[1]
    if "components" in kwargs:
        dims = kwargs["components"]
        if isinstance(dims, str):
            if dims == "all":
                # "1,2", "1,3", "2,3" etc.
                dims = [f"{c1},{c2}" for c1, c2 in itertools.combinations(range(1, n_components + 1), 2)]
            else:
                dims = [dims]

        # Check that dims are valid
        for dim in dims:
            dim1, dim2 = [int(d.strip()) for d in dim.split(",")]
            if dim1 > n_components or dim2 > n_components:
                raise ValueError(f"The given component '{dim}' is larger than the number of components in '{obsm_key}' ({n_components}). Please adjust 'components'.")
    else:
        dims = ["1,2"]
    kwargs["components"] = dims  # overwrite components kwarg with the normalized list

    # Use sc_colormap as default. scanpy accepts the colormap as 'color_map' or 'cmap' and raises
    # an error if both are given, so only inject the default when the user gave neither.
    if "cmap" not in kwargs:
        kwargs["color_map"] = kwargs.get("color_map", sc_colormap())

    parameters = {"color": color,
                  "basis": method,  # sc.pl.embedding can take either "umap" or "X_umap"
                  "show": False}
    if style != "dots":
        parameters["alpha"] = 0  # make dots transparent; hexbin/density is drawn on top instead
    kwargs.update(parameters)
    axarr = sc.pl.embedding(adata, **kwargs)

    # if only one axis is returned, convert to list
    if not isinstance(axarr, list):
        axarr = [axarr]
    if not isinstance(color, list):
        color = [color]

    # Duplicate colors/dimensions to follow the order of the axes returned by scanpy
    # (color1 with every component pair, then color2 with every component pair, ...)
    components = kwargs["components"]
    if len(components) > 1 or len(color) > 1:
        color_list = [color[i // len(components)] for i in range(len(axarr))]  # color1, color1, color2, color2, etc.
        components_list = components * len(color)
    else:
        color_list = color
        components_list = components

    # ---- Adjust style of individual plots ---- #
    for i, ax in enumerate(axarr):

        # Establish which color/dimensions are used for current plot
        ax_color = color_list[i]
        dim1, dim2 = [int(dim.strip()) for dim in components_list[i].split(",")]  # e.g. (1, 2)
        coordinates = adata.obsm[obsm_key][:, [dim1 - 1, dim2 - 1]]

        # Remove title
        if not show_title:
            ax.set_title("")

        # Set titles of legend / colorbar / plot
        legend = ax.get_legend()
        local_axes = ax.figure._localaxes  # list of all plot and colorbar axes in figure (private matplotlib API)
        has_colorbar = False
        if legend is not None:  # legend of categorical variables
            if not show_title:
                legend.set_title(ax_color)
        else:  # legend of continuous variables
            cbar_ax_idx = local_axes.index(ax) + 1  # colorbar is always right after plot
            cbar_ax_idx = min(cbar_ax_idx, len(local_axes) - 1)  # ensure that idx is within bounds
            cbar_ax = local_axes[cbar_ax_idx]
            if cbar_ax._label == "<colorbar>":
                has_colorbar = True  # this ax has colorbar
                if not show_title:
                    cbar_ax.set_title(ax_color)

        # Add additional style to plots
        if style != "dots":

            # Prepare color values
            if ax_color is None:
                color_values = None
            else:
                color_values = utils.adata.get_cell_values(adata, ax_color)

            # Determine colors to use; the colormap may have been given as 'color_map' or 'cmap'
            cmap = kwargs.get("color_map", kwargs.get("cmap"))
            cmap = mpl.rcParams["image.cmap"] if cmap is None else cmap  # if cmap is None, scanpy uses default cmap for matplotlib
            if color_values is None:
                cmap = grey_colormap()  # if no color values are given, use greyscale to show density

            # Plot hexbin/density style if chosen
            if style == "hexbin":

                # Ensure that color is continuous
                if ax_color is not None and not has_colorbar:
                    raise ValueError(f"Hexbin style is only supported for continuous variables, and is not possible for the values found in '{ax_color}'. Please set 'style' to 'dots', 'density' or use a continuous variable.")

                # Plot hexbin without changing the axis limits set by scanpy
                xlim, ylim = ax.get_xlim(), ax.get_ylim()
                hb = ax.hexbin(coordinates[:, 0], y=coordinates[:, 1], C=color_values,
                               mincnt=1, gridsize=hexbin_gridsize, cmap=cmap)
                ax.set_xlim(xlim)
                ax.set_ylim(ylim)

                # Replace colorbar with hexbin values
                if has_colorbar:
                    ax.figure.colorbar(hb, ax=ax, cax=cbar_ax)

                # Set colorbar for number of cells if color is None
                if color_values is None:
                    ax.figure.colorbar(hb, ax=ax, label="Number of cells")
                    cbar_ax = ax.figure.axes[-1]
                    has_colorbar = True

                    # Move colorbar to the correct position in _localaxes
                    index = local_axes.index(ax)
                    local_axes.insert(index + 1, local_axes.pop())  # insert colorbar after plot

            elif style == "density":

                # remove NaN values
                if color_values is not None:
                    is_nan = pd.isna(color_values)  # numpy's isnan throws error for string array
                    color_values = color_values[~is_nan]
                    coordinates = coordinates[~is_nan]

                if ax_color is None:
                    has_colorbar = True  # even non-colored plots have colorbar with density of cells

                # Values are continuous
                if has_colorbar:
                    if color_values is None:
                        sns.kdeplot(x=coordinates[:, 0], y=coordinates[:, 1], fill=True, ax=ax, cmap=cmap, thresh=0.01, cbar=True,
                                    cbar_kws={"label": "Cell density"})
                        cbar_ax = ax.figure.axes[-1]  # colorbar was added to last axis
                    else:
                        # Scale weights to 0-1; constant values would give 0/0 = NaN weights,
                        # so fall back to an unweighted density in that case.
                        value_min = color_values.min()
                        value_range = color_values.max() - value_min
                        if value_range > 0:
                            color_values_scaled = (color_values - value_min) / value_range
                        else:
                            color_values_scaled = np.ones(len(color_values))
                        sns.kdeplot(x=coordinates[:, 0], y=coordinates[:, 1], fill=True, weights=color_values_scaled,
                                    ax=ax, cmap=cmap, thresh=0.01, cbar=True, cbar_ax=cbar_ax, cbar_kws={"label": f"Cell density\n(weighted by {ax_color})"})

                else:  # values are categorical
                    cat2color = dict(zip(adata.obs[ax_color].cat.categories, adata.uns[ax_color + "_colors"]))
                    adata_subsets = utils.get_adata_subsets(adata, groupby=ax_color)

                    for group, adata_sub in adata_subsets.items():
                        coordinates_sub = adata_sub.obsm[obsm_key][:, [dim1 - 1, dim2 - 1]]

                        # Plot kde in color from original plot
                        kde_color = cat2color[group]
                        n_collections_before = len(ax.collections)
                        custom_cmap = LinearSegmentedColormap.from_list(f'{group}_cmap', ['lightgrey', kde_color], N=256)
                        sns.kdeplot(x=coordinates_sub[:, 0], y=coordinates_sub[:, 1], fill=True, ax=ax, cmap=custom_cmap, thresh=0.01)

                        # Set alpha for each level (enables seeing overlapping groups; lowest level are most see-through).
                        # Slice from the count before plotting: if kdeplot added nothing (e.g. too few cells),
                        # a negative slice '[-0:]' would select every collection on the axis.
                        new_objects = ax.collections[n_collections_before:]
                        for obj, alpha in zip(new_objects, np.linspace(0.2, 1, len(new_objects))):
                            obj.set_alpha(alpha)

        # Add contour to plot
        if show_contour:
            _add_contour(coordinates[:, 0], coordinates[:, 1], ax)

        # Remove borders and add small UMAP1/UMAP2 arrows
        if not show_borders:

            # Remove all spines (axes lines)
            for spine in ax.spines.values():
                spine.set_visible(False)

            # Move x and y-labels to the start of axes
            label = ax.xaxis.get_label()
            label.set_horizontalalignment('left')
            x_lab_pos, y_lab_pos = label.get_position()
            label.set_position([0, y_lab_pos])

            label = ax.yaxis.get_label()
            label.set_horizontalalignment('left')
            x_lab_pos, y_lab_pos = label.get_position()
            label.set_position([x_lab_pos, 0])

            # Draw coordinate arrows with 20% of the axis range as length
            ymin, ymax = ax.get_ylim()
            xmin, xmax = ax.get_xlim()
            arrow_len_y = (ymax - ymin) * 0.2
            arrow_len_x = (xmax - xmin) * 0.2
            ax.annotate("", xy=(xmin, ymin), xytext=(xmin, ymin + arrow_len_y), arrowprops=dict(arrowstyle="<-", shrinkB=0))  # UMAP2 / y-axis
            ax.annotate("", xy=(xmin, ymin), xytext=(xmin + arrow_len_x, ymin), arrowprops=dict(arrowstyle="<-", shrinkB=0))  # UMAP1 / x-axis

        # Add number of cells to plot
        if show_count:
            ax.text(0.02, 0.02, f"{adata.n_obs:,} cells",
                    transform=ax.transAxes,
                    horizontalalignment='left',
                    verticalalignment='bottom')

        # Adjust aspect ratio
        if square:
            _make_square(ax)

        # Final formatting of colorbar incl. shrink
        if has_colorbar:
            cbar = cbar_ax._colorbar
            # Need to plot again to gain control of aspect ratio; use ax.figure (not the "current" figure)
            # so the new colorbar is guaranteed to be the last axis of this plot's figure.
            ax.figure.colorbar(cbar.mappable, ax=ax, pad=0.01, aspect=30 * shrink_colorbar, shrink=shrink_colorbar, fraction=0.08, anchor=(0.0, 0.0))
            new_cbar_ax = ax.figure.axes[-1]

            # Carry over title and ylabel
            new_cbar_ax.set_title(cbar_ax.get_title(), fontsize=10)
            new_cbar_ax.set_ylabel(cbar_ax.get_ylabel(), fontsize=10)

            # Set specific cbar style for density plots
            if style == "density":
                # Adjust colorbar to remove density values
                yticks = new_cbar_ax.get_yticks()
                new_cbar_ax.set_yticks([yticks[0], yticks[-1]])
                new_cbar_ax.set_yticklabels(["low", "high"])

            # Replace the original colorbar with the new one at its position in _localaxes
            cbar_idx = local_axes.index(cbar_ax)
            new_cbar_idx = local_axes.index(new_cbar_ax)
            local_axes[cbar_idx] = new_cbar_ax
            local_axes.pop(new_cbar_idx)  # remove original idx of new_cbar_ax

    # Save figure
    _save_figure(save)

    return axarr
[docs]
@deco.log_anndata
@beartype
def search_umap_parameters(adata: sc.AnnData,
                           min_dist_range: Tuple[float | int, float | int, float | int] = (0.2, 0.9, 0.2),  # 0.2, 0.4, 0.6, 0.8
                           spread_range: Tuple[float | int, float | int, float | int] = (0.5, 2.0, 0.5),  # 0.5, 1.0, 1.5
                           color: Optional[str] = None,
                           n_components: int = 2,
                           threads: int = 4,
                           save: Optional[str] = None,
                           **kwargs: Any) -> np.ndarray:
    """Plot a grid of UMAPs computed with different combinations of 'min_dist' and 'spread'.

    Columns of the grid correspond to 'min_dist' values and rows to 'spread' values.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    min_dist_range : Tuple[float | int, float | int, float | int], default (0.2, 0.9, 0.2)
        Range of 'min_dist' parameter values to test. Must be a tuple in the form (min, max, step); 'max' is exclusive.
    spread_range : Tuple[float | int, float | int, float | int], default (0.5, 2.0, 0.5)
        Range of 'spread' parameter values to test. Must be a tuple in the form (min, max, step); 'max' is exclusive.
    color : Optional[str], default None
        Name of the column in adata.obs to color plots by. If None, plots are not colored.
    n_components : int, default 2
        Number of components in UMAP calculation.
    threads : int, default 4
        Number of threads to use for UMAP calculation.
    save : Optional[str], default None
        Path to save the figure to. If None, the figure is not saved.
    **kwargs : Any
        Additional keyword arguments are passed to :func:`scanpy.tl.umap`.

    Returns
    -------
    np.ndarray
        2D numpy array of axis objects

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.search_umap_parameters(adata, min_dist_range=(0.2, 0.9, 0.2),
                                  spread_range=(2.0, 3.0, 0.5),
                                  color="bulk_labels")
    """
    # 'n_components' is not a parameter of the shared helper, so it travels
    # through the helper's **kwargs on to sc.tl.umap.
    return _search_dim_red_parameters(adata,
                                      method="umap",
                                      min_dist_range=min_dist_range,
                                      spread_range=spread_range,
                                      color=color,
                                      n_components=n_components,
                                      threads=threads,
                                      save=save,
                                      **kwargs)
[docs]
@deco.log_anndata
@beartype
def search_tsne_parameters(adata: sc.AnnData,
                           perplexity_range: Tuple[int, int, int] = (30, 60, 10),
                           learning_rate_range: Tuple[int, int, int] = (600, 1000, 200),
                           color: Optional[str] = None,
                           threads: int = 4,
                           save: Optional[str] = None,
                           **kwargs: Any) -> np.ndarray:
    """Plot a grid of different combinations of perplexity and learning_rate variables for tSNE plots.

    Columns of the grid correspond to 'perplexity' values and rows to 'learning_rate' values.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    perplexity_range : Tuple[int, int, int], default (30, 60, 10)
        tSNE parameter: Range of 'perplexity' parameter values to test. Must be a tuple in the form (min, max, step); 'max' is exclusive.
    learning_rate_range : Tuple[int, int, int], default (600, 1000, 200)
        tSNE parameter: Range of 'learning_rate' parameter values to test. Must be a tuple in the form (min, max, step); 'max' is exclusive.
    color : Optional[str], default None
        Name of the column in adata.obs to color plots by. If None, plots are not colored.
    threads : int, default 4
        Passed as 'n_jobs' to :func:`scanpy.tl.tsne`. The individual tSNE embeddings are computed
        one after another; 'n_jobs' can also be overwritten directly via kwargs.
    save : Optional[str], default None
        Path to save the figure to. If None, the figure is not saved.
    **kwargs : Any
        Additional keyword arguments are passed to :func:`scanpy.tl.tsne`.

    Returns
    -------
    np.ndarray
        2D numpy array of axis objects

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.search_tsne_parameters(adata, perplexity_range=(30, 60, 10),
                                  learning_rate_range=(600, 1000, 200),
                                  color="bulk_labels")
    """
    # Forward arguments explicitly rather than via locals(), which silently picks up any local name
    return _search_dim_red_parameters(adata,
                                      method="tsne",
                                      perplexity_range=perplexity_range,
                                      learning_rate_range=learning_rate_range,
                                      color=color,
                                      threads=threads,
                                      save=save,
                                      **kwargs)
@beartype
def _search_dim_red_parameters(adata: sc.AnnData,
                               method: Literal["umap", "tsne"],
                               min_dist_range: Optional[Tuple[int | float, int | float, int | float]] = None,  # for UMAP
                               spread_range: Optional[Tuple[int | float, int | float, int | float]] = None,  # for UMAP
                               perplexity_range: Optional[Tuple[int, int, int]] = None,  # for tSNE
                               learning_rate_range: Optional[Tuple[int, int, int]] = None,  # for tSNE
                               color: Optional[str] = None,
                               threads: int = 4,
                               save: Optional[str] = None,
                               **kwargs: Any) -> np.ndarray:
    """Search different combinations of parameters for UMAP or tSNE and plot a grid of the embeddings.

    Parameter values are generated with ``numpy.arange(min, max, step)`` (so 'max' is exclusive)
    and rounded to 2 decimals. Columns of the grid correspond to the first parameter
    ('min_dist' / 'perplexity') and rows to the second ('spread' / 'learning_rate').

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    method : Literal["umap", "tsne"]
        Dimensionality reduction method to use. Must be either 'umap' or 'tsne'.
    min_dist_range : Optional[Tuple[int | float, int | float, int | float]], default None
        UMAP parameter: Range of 'min_dist' parameter values to test. Must be a tuple in the form (min, max, step).
    spread_range : Optional[Tuple[int | float, int | float, int | float]], default None
        UMAP parameter: Range of 'spread' parameter values to test. Must be a tuple in the form (min, max, step).
    perplexity_range : Optional[Tuple[int, int, int]], default None
        tSNE parameter: Range of 'perplexity' parameter values to test. Must be a tuple in the form (min, max, step).
    learning_rate_range : Optional[Tuple[int, int, int]], default None
        tSNE parameter: Range of 'learning_rate' parameter values to test. Must be a tuple in the form (min, max, step).
    color : Optional[str], default None
        Name of the column in adata.obs to color plots by. If None, plots are not colored.
    threads : int, default 4
        Number of threads to use for calculating embeddings. In case of UMAP, the embeddings will be calculated in parallel with each job using 1 thread.
        For tSNE, the embeddings are calculated serially, but each calculation uses 'threads' as 'n_jobs' within sc.tl.tsne.
    save : Optional[str], default None
        Path to save the figure to.
    **kwargs : Any
        Additional keyword arguments are passed to :func:`scanpy.tl.umap` or :func:`scanpy.tl.tsne`.

    Returns
    -------
    np.ndarray
        2D numpy array of axis objects

    Raises
    ------
    ValueError
        If a range is not of the form (min, max, step), if 'step' is not positive,
        or if 'step' is larger than 'max' - 'min'.
    """

    def get_loop_params(r):
        """Return the rounded parameter values for a range given as [name, min, max, step]."""
        # Check validity of range parameters
        if len(r) != 4:
            raise ValueError(f"The parameter '{r[0]}' must be a tuple in the form (min, max, step)")
        if r[3] <= 0:
            # step == 0 would make np.arange fail; step < 0 would give an empty grid
            raise ValueError(f"'step' of '{r[0]}' must be larger than 0. Please adjust.")
        if r[3] > r[2] - r[1]:
            raise ValueError(f"'step' of '{r[0]}' is larger than 'max' - 'min'. Please adjust.")
        return np.around(np.arange(r[1], r[2], r[3]), 2)

    # remove data to save memory
    adata = utils.get_minimal_adata(adata)

    # Allows for all case variants of method parameter
    method = method.lower()
    if method == "umap":
        range_1 = ["min_dist_range"] + list(min_dist_range)
        range_2 = ["spread_range"] + list(spread_range)
    elif method == "tsne":
        range_1 = ["perplexity_range"] + list(perplexity_range)
        range_2 = ["learning_rate_range"] + list(learning_rate_range)

    # Get tool function (sc.tl.umap or sc.tl.tsne)
    tool_func = getattr(sc.tl, method)

    # Setup loop parameters: [0] = columns (range_1), [1] = rows (range_2)
    loop_params = [get_loop_params(r) for r in (range_1, range_2)]
    n_cols, n_rows = len(loop_params[0]), len(loop_params[1])

    # Names of the tool parameters, e.g. "min_dist" from "min_dist_range"
    param_1 = range_1[0].rsplit('_', 1)[0]
    param_2 = range_2[0].rsplit('_', 1)[0]

    # Only UMAP is computed in parallel (one job per parameter combination)
    run_parallel = threads > 1 and method == "umap"

    # Calculate umap/tsne for each combination of parameters
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=numba_errors.NumbaDeprecationWarning)  # numba warning for 0.59.0 (only for UMAP)
        warnings.filterwarnings("ignore", category=UserWarning, message="In previous versions of scanpy, calling tsne with n_jobs > 1 would use MulticoreTSNE.")

        if run_parallel:
            pool = mp.Pool(threads)
        else:
            pbar = utils.get_pbar(n_cols * n_rows, f"Computing {method.upper()}s")

        # Setup jobs
        jobs = {}
        try:
            for i, r2_param in enumerate(loop_params[1]):  # rows
                for j, r1_param in enumerate(loop_params[0]):  # columns

                    kwds = {param_1: r1_param,
                            param_2: r2_param,
                            "copy": True}
                    if method == "tsne":
                        kwds["n_jobs"] = threads
                    kwds |= kwargs  # gives the option to overwrite e.g. n_jobs if given in kwargs
                    logger.debug(f"Running '{method}' with kwds: {kwds}")

                    if run_parallel:
                        job = pool.apply_async(tool_func, args=(adata, ), kwds=kwds)
                    else:
                        job = tool_func(adata, **kwds)  # run the tool function one by one; returns an anndata object
                        pbar.update(1)
                    jobs[(i, j)] = job

            if run_parallel:
                pool.close()
                utils.monitor_jobs(jobs, f"Computing {method.upper()}s")
                pool.join()
        except BaseException:
            # Do not leave worker processes running if submitting or monitoring jobs fails
            if run_parallel:
                pool.terminate()
            raise

    # Figure with rows=range_2 (e.g. spread), cols=range_1 (e.g. min_dist)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    axes = np.array(axes).reshape((n_rows, n_cols))  # always allow indexing as axes[row, col]

    # Fill in embeddings
    for i, r2_param in enumerate(loop_params[1]):  # rows
        for j, r1_param in enumerate(loop_params[0]):  # columns

            if run_parallel:
                jobs[(i, j)] = jobs[(i, j)].get()

            # Add precalculated embedding to adata
            adata.obsm[f"X_{method}"] = jobs[(i, j)].obsm[f"X_{method}"]

            logger.debug(f"Plotting {method} for row={r2_param} and col={r1_param} ({i * n_cols + j + 1}/{n_cols * n_rows})")

            # Only the top-right panel shows a legend, placed in its right margin
            # ("left" is not a valid legend location for scanpy/matplotlib)
            legend_loc = "right margin" if (i == 0 and j == n_cols - 1) else "none"

            ax = axes[i, j]
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning, message="No data for colormapping provided via 'c'*")
                sc.pl.embedding(adata, basis="X_" + method, color=color, title='', legend_loc=legend_loc, show=False, ax=ax)

            if j == 0:
                ax.set_ylabel(f"{param_2}: {r2_param}", fontsize=14)
            else:
                ax.set_ylabel("")

            if i == 0:
                ax.set_title(f"{param_1}: {r1_param}", fontsize=14)

            ax.set_xlabel("")

    plt.tight_layout()
    _save_figure(save)

    return axes
#######################################################################################
# -------------------------- Different group embeddings ------------------------------#
#######################################################################################
[docs]
@deco.log_anndata
@beartype
def plot_group_embeddings(adata: sc.AnnData,
                          groupby: str,
                          embedding: Literal["umap", "tsne", "pca"] = "umap",
                          ncols: int = 4,
                          save: Optional[str] = None,
                          **kwargs: Any) -> np.ndarray:
    """Plot a grid of embeddings (UMAP/tSNE/PCA) per group of cells within 'groupby'.

    Each panel highlights the cells of one category of adata.obs[groupby].

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    groupby : str
        Name of the column in adata.obs to group by.
    embedding : Literal["umap", "tsne", "pca"], default "umap"
        Embedding to plot. Must be one of "umap", "tsne", "pca".
    ncols : int, default 4
        Number of columns in the figure. It is reduced to the number of groups if there are fewer groups.
    save : Optional[str], default None
        Path to save the figure.
    **kwargs : Any
        Additional keyword arguments are passed to :func:`scanpy.pl.umap` or :func:`scanpy.pl.tsne` or :func:`scanpy.pl.pca`.

    Returns
    -------
    np.ndarray
        2D numpy array (nrows x ncols) of axis objects. Unused panels in the last row are hidden.

    Raises
    ------
    ValueError
        If 'ncols' is smaller than 1 or adata.obs[groupby] contains no groups.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.plot_group_embeddings(adata, 'phase', embedding='umap', ncols=4)
    """
    if ncols < 1:
        raise ValueError(f"'ncols' must be at least 1, but got {ncols}.")

    # Get categories
    groups = adata.obs[groupby].astype("category").cat.categories
    n_groups = len(groups)
    if n_groups == 0:
        raise ValueError(f"No groups found in adata.obs['{groupby}'].")

    # Find out how many rows are needed
    ncols = min(ncols, n_groups)  # Make sure ncols is not larger than n_groups
    nrows = int(np.ceil(n_groups / ncols))

    # Setup subplots; reshape so single-row/column grids are still 2D
    fig, axarr = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 5))
    axarr = np.array(axarr).reshape((nrows, ncols))
    axes_list = axarr.flatten()

    # sc.pl.umap / sc.pl.tsne / sc.pl.pca
    plot_func = getattr(sc.pl, embedding)

    # Plot UMAP/tSNE/pca per group
    for ax, group in zip(axes_list, groups):
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning, message="Categorical.replace is deprecated")
            warnings.filterwarnings("ignore", category=FutureWarning, message="In a future version of pandas")
            warnings.filterwarnings("ignore", category=UserWarning, message="No data for colormapping provided via 'c'*")

            # Plot individual embedding with only 'group' highlighted
            plot_func(adata, color=groupby, groups=group, ax=ax, show=False, legend_loc=None, **kwargs)
        ax.set_title(group)

    # Hide last empty plots
    for ax in axes_list[n_groups:]:
        ax.set_visible(False)

    # Save figure
    _save_figure(save)

    return axarr
[docs]
@beartype
def compare_embeddings(adata_list: list[sc.AnnData],
                       var_list: list[str] | str,
                       embedding: Literal["umap", "tsne", "pca"] = "umap",
                       adata_names: Optional[list[str]] = None,
                       **kwargs: Any) -> np.ndarray:
    """Compare embeddings across different adata objects.

    Plots a grid of embeddings with the different adatas on the x-axis, and colored variables on the y-axis.
    Variables that are missing from a specific adata object are plotted without coloring for that object.

    Parameters
    ----------
    adata_list : list[sc.AnnData]
        List of AnnData objects to compare.
    var_list : list[str] | str
        List of variables to color in plot. Variables can be columns of adata.obs or names in adata.var.index.
    embedding : Literal["umap", "tsne", "pca"], default "umap"
        Embedding to plot. Must be one of "umap", "tsne" or "pca".
    adata_names : Optional[list[str]], default None (adatas will be named adata_1, adata_2, etc.)
        List of names for the adata objects. Must contain one name per object in adata_list, or be None.
    **kwargs : Any
        Additional arguments to pass to sc.pl.umap/sc.pl.tsne/sc.pl.pca.

    Returns
    -------
    np.ndarray
        2D numpy array of axis objects (rows = variables, columns = adata objects)

    Raises
    ------
    ValueError
        If none of the variables in var_list are found in any of the adata objects,
        or if adata_names contains fewer names than there are adata objects.

    Examples
    --------
    .. plot::
        :context: close-figs

        import scanpy as sc

    .. plot::
        :context: close-figs

        adata1 = sc.datasets.pbmc68k_reduced()
        adata2 = sc.datasets.pbmc3k_processed()
        adata_list = [adata1, adata2]
        var_list = ['n_counts', 'n_cells']

    .. plot::
        :context: close-figs

        pl.compare_embeddings(adata_list, var_list)
    """
    embedding = embedding.lower()

    if isinstance(var_list, str):
        var_list = [var_list]

    # Names of the adata objects (used as column titles)
    if adata_names is None:
        adata_names = [f"adata_{n + 1}" for n in range(len(adata_list))]
    elif len(adata_names) < len(adata_list):
        raise ValueError(f"'adata_names' must contain one name per adata object ({len(adata_list)}), but only {len(adata_names)} were given.")

    # Variables available per adata object (genes in .var and columns in .obs) and across all objects
    available_per_adata = [set(adata.var.index) | set(adata.obs.columns) for adata in adata_list]
    all_vars = set().union(*available_per_adata)

    # Subset var list to those available in any of the adata objects
    not_found = set(var_list) - all_vars
    var_list = [var for var in var_list if var in all_vars]
    if not var_list:
        raise ValueError("None of the variables from var_list were found in the adata objects.")
    elif not_found:
        logger.warning(f"The following variables from var_list were not found in any of the adata objects: {list(not_found)}. These will be excluded.")

    # Setup plot grid; reshape so single-row/column grids can still be indexed as axes[row, col]
    n_adata = len(adata_list)
    n_var = len(var_list)
    fig, axes = plt.subplots(n_var, n_adata, figsize=(4 * n_adata, 4 * n_var))
    axes = np.array(axes).reshape((n_var, n_adata))

    # sc.pl.umap / sc.pl.tsne / sc.pl.pca
    plot_func = getattr(sc.pl, embedding)

    for i, (adata, available) in enumerate(zip(adata_list, available_per_adata)):
        for j, var in enumerate(var_list):
            ax = axes[j, i]

            # Check if var is available for this specific adata; keep 'var' itself for the row label
            if var in available:
                color_var = var
            else:
                logger.warning(f"Variable '{var}' was not found in adata object '{adata_names[i]}'. Skipping coloring.")
                color_var = None

            plot_func(adata, color=color_var, show=False, ax=ax, **kwargs)

            # Row labels only in the first column, column titles only in the first row
            ax.set_ylabel(var if i == 0 else "")
            ax.set_title(adata_names[i] if j == 0 else "")
            ax.set_xlabel("")

            _make_square(ax)

    return axes
#######################################################################################
# ---------------------------------- 3D UMAP -----------------------------------------#
#######################################################################################
@beartype
def _get_3d_dotsize(n: int) -> int:
    """Return the marker size to use for a 3D scatter plot of ``n`` points.

    Fewer points get larger markers: 12 below 1,000 points, 8 below 10,000
    points and 3 otherwise.
    """
    # (exclusive upper bound on n, dot size) pairs, checked in ascending order
    for upper_bound, dotsize in ((1000, 12), (10000, 8)):
        if n < upper_bound:
            return dotsize
    return 3
[docs]
@deco.log_anndata
@beartype
def plot_3D_UMAP(adata: sc.AnnData,
                 color: str,
                 save: str,
                 **kwargs: Any) -> None:
    """Save 3D UMAP plot to a html file.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix. The first three columns of adata.obsm['X_umap'] are used as coordinates.
    color : str
        Variable to color in plot. Must be a column in adata.obs or an index in adata.var.
        adata.obs columns holding strings are drawn as one trace per group; numeric columns
        and genes are drawn with a continuous colorscale.
    save : str
        Save prefix. Plot will be saved to <save>.html.
    **kwargs : Any
        Additional keyword arguments are passed to :func:`plotly.graph_objects.Scatter3d`
        (for both categorical and continuous coloring).

    Raises
    ------
    KeyError
        If the given 'color' attribute was not found in adata.obs columns or adata.var index.
    ValueError
        If adata.obsm['X_umap'] contains fewer than 3 components.

    Examples
    --------
    .. plot::
        :context: close-figs

        min_dist = 0.3
        spread = 2.5
        sc.tl.umap(adata, min_dist=min_dist, spread=spread, n_components=3)

    .. plot::
        :context: close-figs

        pl.plot_3D_UMAP(adata, color="louvain", save="my3d_umap")

    This will create an .html-file with the interactive 3D UMAP: :download:`my3d_umap.html <my3d_umap.html>`
    """
    n_cells = len(adata.obs)
    size = _get_3d_dotsize(n_cells)

    # Get coordinates of the first three UMAP components
    umap_coords = adata.obsm['X_umap']
    if umap_coords.shape[1] < 3:
        raise ValueError("adata.obsm['X_umap'] contains fewer than 3 components. "
                         "Please compute the UMAP with n_components=3 before plotting in 3D.")
    df = pd.DataFrame(umap_coords[:, :3], columns=["x", "y", "z"])

    # Create plot
    po.offline.init_notebook_mode(connected=True)  # prints a dict when not run in notebook
    fig = go.Figure()

    # A column is treated as categorical if its first value is a string.
    # .iloc is used because obs is indexed by cell names, not by position.
    obs_values = adata.obs[color] if color in adata.obs.columns else None
    is_categorical = obs_values is not None and n_cells > 0 and isinstance(obs_values.iloc[0], str)

    # Plot per group in obs
    if is_categorical:
        # Cast to str so that missing values form a "nan" group instead of breaking string concatenation below
        df["category"] = obs_values.astype(str).values
        categories = df["category"].unique()
        color_list = list(map(colors.to_hex, sns.color_palette("Set1", len(categories))))  # convert to hex

        for group_color, name in zip(color_list, categories):
            df_sub = df[df['category'] == name]
            go_plot = go.Scatter3d(x=df_sub['x'],
                                   y=df_sub['y'],
                                   z=df_sub['z'],
                                   name=name,
                                   hovertemplate=name + '<br>(' + str(len(df_sub)) + ' cells)<extra></extra>',
                                   showlegend=True,
                                   mode='markers',
                                   marker=dict(size=size,
                                               color=group_color,
                                               opacity=0.8),
                                   **kwargs)
            fig.add_trace(go_plot)

    # Plot a continuous value (numeric obs column or gene expression)
    else:
        if obs_values is not None:
            # Color is a value column in obs
            color_values = obs_values
        elif color in adata.var.index:
            # color is a gene; first occurrence is used if var names are duplicated
            color_idx = list(adata.var.index).index(color)
            color_values = adata.X[:, color_idx]
            color_values = color_values.todense().A1 if issparse(color_values) else color_values
        else:
            raise KeyError("The given 'color' attribute was not found in adata.obs columns or adata.var index.")

        # Plot 3d with colorbar
        go_plot = go.Scatter3d(x=df['x'],
                               y=df['y'],
                               z=df['z'],
                               name='Expression of ' + color,
                               hovertemplate='Expression of ' + color + '<br>(' + str(len(df)) + ' cells)<extra></extra>',
                               showlegend=True,
                               mode='markers',
                               marker=dict(size=size,
                                           color=color_values,
                                           colorscale='Viridis',
                                           colorbar=dict(thickness=20, lenmode='fraction', len=0.75),
                                           opacity=0.8),
                               **kwargs)
        fig.add_trace(go_plot)

    # Finalize plot
    fig.update_layout(legend={'itemsizing': 'constant'}, legend_title_text='<br><br>' + color)
    fig.update_scenes(xaxis=dict(showspikes=False),
                      yaxis=dict(showspikes=False),
                      zaxis=dict(showspikes=False))
    fig.update_layout(scene=dict(xaxis_title='UMAP1',
                                 yaxis_title='UMAP2',
                                 zaxis_title='UMAP3'))
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))

    # Save to file
    if isinstance(save, str):
        path = settings.full_figure_prefix + save + ".html"
        fig.write_html(path)
        logger.info(f"Plot written to '{path}'")
    else:
        logger.error("Please specify save parameter for html export")
[docs]
@deco.log_anndata
@beartype
def umap_marker_overview(adata: sc.AnnData,
                         markers: list[str] | str,
                         ncols: int = 3,
                         figsize: Optional[Tuple[int, int]] = None,
                         save: Optional[str] = None,
                         cbar_label: str = "Relative expr.",
                         **kwargs: Any) -> list:
    """Plot a pretty grid of UMAPs with marker gene expression.

    A single colorbar without ticks is placed next to the last plot.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix.
    markers : list[str] | str
        List of markers or single marker.
    ncols : int, default 3
        Number of columns in grid.
    figsize : Optional[Tuple[int, int]], default None
        Tuple of figure size. Defaults to 3 inches per column and per row.
    save : Optional[str], default None
        If not None save plot under given name.
    cbar_label : str, default "Relative expr."
        Colorbar label
    **kwargs : Any
        Additional parameter for scanpy.pl.umap()

    Returns
    -------
    list
        List of axis objects (one per marker).

    Raises
    ------
    ValueError
        If no markers are given or if 'ncols' is smaller than 1.
    """
    if isinstance(markers, str):
        markers = [markers]

    n_markers = len(markers)
    if n_markers == 0:
        raise ValueError("Please provide at least one marker to plot.")
    if ncols < 1:
        raise ValueError("'ncols' has to be a positive integer.")

    # Find out how many rows we need
    nrows = int(np.ceil(n_markers / ncols))
    if figsize is None:
        figsize = (ncols * 3, nrows * 3)

    # squeeze=False always returns a 2D array of axes, even for a 1x1 grid
    fig, axarr = plt.subplots(ncols=ncols, nrows=nrows, figsize=figsize, squeeze=False)

    params = {"cmap": sc_colormap(),
              "ncols": ncols,
              "frameon": False}
    params.update(**kwargs)

    all_axes = axarr.flatten()
    for ax, marker in zip(all_axes, markers):
        sc.pl.umap(adata,
                   color=marker,
                   show=False,
                   colorbar_loc=None,  # a single shared colorbar is added below
                   ax=ax,
                   **params)

    # Hide axes not used
    for ax in all_axes[n_markers:]:
        ax.set_visible(False)
    axes_list = all_axes[:n_markers]

    # Add colorbar next to the last plot
    cax = fig.add_axes([0, 0, 1, 1])  # dummy size, will be resized
    lastax_pos = axes_list[-1].get_position()  # get the position of the last axis
    newpos = [lastax_pos.x1 * 1.1, lastax_pos.y0, lastax_pos.width * 0.1, lastax_pos.height * 0.5]
    cax.set_position(newpos)  # set a new position

    cbar = plt.colorbar(cm.ScalarMappable(cmap=params["cmap"]), cax=cax, label=cbar_label)
    cbar.set_ticks([])
    cbar.outline.set_visible(False)  # remove border of colorbar

    # Make plots square
    for ax in axes_list:
        _make_square(ax)

    # Save figure if chosen
    _save_figure(save)

    return list(axes_list)
# Type hint for a list of plot names, used for the 'plots' parameter of anndata_overview.
# Besides the List[Literal[...]] hint, the Is[...] validator explicitly checks *every* entry
# of the list against _VALID_PLOTS. Presumably this works around beartype only checking a
# sample of container items (see https://github.com/beartype/beartype/issues/347) - confirm.
# NOTE: _VALID_PLOTS and the Literal below must list the same plot names.
_VALID_PLOTS = frozenset(("UMAP", "tSNE", "PCA", "PCA-var", "LISI"))
ListOfValidPlots = Annotated[List[Literal["UMAP", "tSNE", "PCA", "PCA-var", "LISI"]], Is[
lambda lst: all(item in _VALID_PLOTS for item in lst)]]
[docs]
@beartype
def anndata_overview(adatas: dict[str, sc.AnnData],
                     color_by: str | list[str],
                     plots: Union[ListOfValidPlots,
                                  Literal["UMAP", "tSNE", "PCA", "PCA-var", "LISI"]] = ["PCA", "PCA-var", "UMAP", "LISI"],
                     figsize: Optional[Tuple[int, int]] = None,
                     max_clusters: int = 20,
                     output: Optional[str] = None,
                     dpi: int = 300,
                     **kwargs: Any) -> npt.ArrayLike:
    """Create a multipanel plot comparing PCA/UMAP/tSNE/(...) plots for different adata objects.

    Each adata gets one column. Each embedding plot type gets one row per entry in `color_by`;
    "PCA-var" and "LISI" get a single row each.

    Parameters
    ----------
    adatas : dict[str, sc.AnnData]
        Dict containing an anndata object for each batch correction method as values. Keys are the name of the respective method.
        E.g.: {"bbknn": anndata}
    color_by : str | list[str]
        Name of the .obs column (or .var index entry) to use for coloring in applicable plots (e.g. for UMAP or PCA).
    plots : Union[list[Literal["UMAP", "tSNE", "PCA", "PCA-var", "LISI"]],
                  Literal["UMAP", "tSNE", "PCA", "PCA-var", "LISI"]], default ["PCA", "PCA-var", "UMAP", "LISI"]
        Decide which plots should be created. Options are ["UMAP", "tSNE", "PCA", "PCA-var", "LISI"]
        Note: List order is forwarded to plot.
        - UMAP: Plots the UMAP embedding of the data.
        - tSNE: Plots the tSNE embedding of the data.
        - PCA: Plots the PCA embedding of the data.
        - PCA-var: Plots the variance explained by each PCA component.
        - LISI: Plots the distribution of any "LISI_score*" scores available in adata.obs
    figsize : Optional[Tuple[int, int]], default None
        Size of the plot in inch. Defaults to automatic size based on number of columns/rows.
    max_clusters : int, default 20
        Maximum number of clusters to show in legend.
    output : Optional[str], default None
        Path to plot output file.
    dpi : int, default 300
        Dots per inch for output
    **kwargs : Any
        Additional keyword arguments are passed to :func:`scanpy.pl.umap`, :func:`scanpy.pl.tsne` or :func:`scanpy.pl.pca`.

    Returns
    -------
    axes : npt.ArrayLike
        Array of matplotlib.axes.Axes objects created by matplotlib.

    Raises
    ------
    ValueError
        If any of the adatas is not of type anndata.AnnData, if there is nothing to plot
        (no adatas or no rows), if a 'color_by' entry is missing from adata.obs/adata.var,
        or if "LISI" is requested but no "LISI_score*" columns exist in adata.obs.

    Examples
    --------
    .. plot::
        :context: close-figs

        adatas = {}  # dictionary of adata objects
        adatas["standard"] = adata
        adatas["parameter1"] = sc.tl.umap(adata, min_dist=1, copy=True)
        adatas["parameter2"] = sc.tl.umap(adata, min_dist=2, copy=True)

        pl.anndata_overview(adatas, color_by="louvain", plots=["PCA", "PCA-var", "UMAP"])
    """
    if not isinstance(color_by, list):
        color_by = [color_by]
    if not isinstance(plots, list):
        plots = [plots]  # rebinding only; the default list is never mutated

    # ---- helper functions ---- #
    def annotate_row(ax, plot_type):
        """Write the plot type as a bold row label to the left of the y-axis label of ax."""
        # https://stackoverflow.com/a/25814386
        ax.annotate(plot_type,
                    xy=(0, 0.5),
                    xytext=(-ax.yaxis.labelpad - 5, 0),
                    xycoords=ax.yaxis.label,
                    textcoords='offset points',
                    size=ax.title.get_fontsize() * 1.2,  # increase title fontsize
                    horizontalalignment='right',
                    verticalalignment='center',
                    fontweight='bold')

    def replace_legend(ax, title):
        """Replace the legend of ax by one right of the plot showing at most max_clusters entries."""
        lines, labels = ax.get_legend_handles_labels()
        ax.get_legend().remove()

        per_column = 10
        n_clusters = min(max_clusters, len(lines))
        n_cols = max(1, int(np.ceil(n_clusters / per_column)))

        # 'ncols' was introduced in matplotlib 3.6 (older versions only accept 'ncol').
        # Compare parsed version numbers: a string comparison treats e.g. "3.10" as older than "3.6".
        mpl_version = tuple(getattr(mpl, "__version_info__", (0, 0)))[:2]
        ncols_kwarg = "ncols" if mpl_version >= (3, 6) else "ncol"
        ax.legend(lines[:max_clusters], labels[:max_clusters],
                  title=title, frameon=False,
                  bbox_to_anchor=(1.05, 0.5),
                  loc=6,
                  **{ncols_kwarg: n_cols})

    # ---- checks ---- #
    # dict contains only anndata
    wrong_type = {k: type(v) for k, v in adatas.items() if not isinstance(v, sc.AnnData)}
    if wrong_type:
        raise ValueError(f"All items in 'adatas' parameter have to be of type AnnData. Found: {wrong_type}")
    if not adatas:
        raise ValueError("The 'adatas' parameter has to contain at least one AnnData object.")

    # check if color_by exists in anndata.obs
    for color_group in color_by:
        for name, adata in adatas.items():
            if color_group not in adata.obs.columns and color_group not in adata.var.index:
                raise ValueError(f"Couldn't find column '{color_group}' in the adata.obs or adata.var for '{name}'")

    # ---- plotting ---- #
    # setup subplot structure
    row_count = {"PCA-var": 1, "LISI": 1}  # all other plots count for len(color_by)
    rows = sum(row_count.get(plot, len(color_by)) for plot in plots)  # the number of rows in output plot
    cols = len(adatas)
    if rows == 0:
        raise ValueError("There is nothing to plot. Please check the 'plots' and 'color_by' parameters.")

    figsize = figsize if figsize is not None else (2 + cols * 4, rows * 4)
    fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=figsize)
    axs = axs.flatten() if rows > 1 or cols > 1 else [axs]  # flatten to 1d array per row

    embedding_functions = {"UMAP": sc.pl.umap, "tSNE": sc.pl.tsne, "PCA": sc.pl.pca}

    # Fill in plots for every adata across plot type and color_by
    ax_idx = 0
    LISI_axes = []
    for plot_type in plots:
        for color in color_by:

            # Collect all categories of 'color' across adatas so every column uses the same colors.
            # Numeric obs columns and genes (var) are continuous and get no palette.
            categories = set()
            for adata in adatas.values():
                if color in adata.obs.select_dtypes(exclude="number").columns:
                    categories.update(adata.obs[color].dropna().unique())
            categories = sorted(categories)

            if categories:
                palette_colors = sns.color_palette("tab10", len(categories))
                color_dict = dict(zip(categories, palette_colors))
            else:
                color_dict = None  # use default color palette

            # Plot for each adata (one row)
            for i, (name, adata) in enumerate(adatas.items()):
                ax = axs[ax_idx]
                is_last_column = i == len(adatas) - 1

                # Only show legend for the last column
                if is_last_column:
                    legend_loc = "right margin"
                    # Disable colorbar for continuous values (will be re-added later)
                    colorbar_loc = "right" if color in adata.obs.select_dtypes(exclude="number").columns else None
                else:
                    legend_loc = "none"
                    colorbar_loc = None

                # add row label to first plot
                if i == 0:
                    annotate_row(ax, plot_type)

                # Collect options for plotting
                embedding_kwargs = {"color": color,
                                    "palette": color_dict,  # only used for categorical color
                                    "title": "",
                                    "legend_loc": legend_loc,
                                    "colorbar_loc": colorbar_loc,
                                    "show": False}
                embedding_kwargs.update(**kwargs)  # overwrite with kwargs from user

                # Plot depending on type
                if plot_type == "PCA-var":
                    plot_pca_variance(adata, ax=ax, show_cumulative=False)  # this plot takes no color

                elif plot_type == "LISI":
                    # Find any LISI scores in adata.obs
                    lisi_columns = [col for col in adata.obs.columns if col.startswith("LISI_score")]
                    if not lisi_columns:
                        raise ValueError(f"No LISI scores found in adata.obs for '{name}'. "
                                         "Please run 'sctoolbox.tools.wrap_batch_evaluation()' or remove LISI from the plots list")

                    boxplot(adata.obs[lisi_columns], ax=ax)
                    LISI_axes.append(ax)

                else:
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", category=UserWarning, message="No data for colormapping provided via 'c'*")
                        embedding_functions[plot_type](adata, ax=ax, **embedding_kwargs)

                    # Legend/colorbar adjustments only apply to embedding plots
                    if ax.get_legend() is not None:
                        # Categorical color: redraw legend with title and limited number of clusters
                        replace_legend(ax, color)

                    elif is_last_column and (color in adata.obs.select_dtypes(include="number").columns or color in adata.var.index):
                        # Replace native scanpy colorbar with self-made one to gain the ability to set a label
                        # Size parameter values are taken from scanpy: https://github.com/scverse/scanpy/blob/383a61b2db0c45ba622f231f01d0e7546d99566b/scanpy/plotting/_tools/scatterplots.py#L456
                        if len(ax.collections) > 0:
                            plt.colorbar(ax.collections[0], pad=0.01, fraction=0.08, aspect=30, ax=ax, orientation='vertical', label=color)

                _make_square(ax)
                ax_idx += 1  # increment index for next plot

            if plot_type in row_count:
                break  # If not dependent on color; break off early from color_by loop

    # Set common y-axis limit for LISI plots
    if LISI_axes:
        y_limits = [lisi_ax.get_ylim() for lisi_ax in LISI_axes]
        min_y = min(lim[0] for lim in y_limits)
        max_y = max(lim[1] for lim in y_limits)
        for lisi_ax in LISI_axes:
            lisi_ax.set_ylim(min_y, max_y)  # scale all plots to same y-limits

        LISI_axes[0].set_ylabel("Unique batch labels in cell neighborhood")

    # Finalize axes titles and labels: the first row shows the adata names
    for i, name in enumerate(adatas):
        fontsize = axs[i].title.get_fontsize() * 1.2  # increase title fontsize
        axs[i].set_title(name, size=fontsize, fontweight='bold')

    # save
    _save_figure(output, dpi=dpi)

    return axs
[docs]
@deco.log_anndata
@beartype
def plot_pca_variance(adata: sc.AnnData,
                      method: str = "pca",
                      n_pcs: int = 20,
                      selected: Optional[List[int]] = None,
                      show_cumulative: bool = True,
                      n_thresh: Optional[int] = None,
                      corr_plot: Optional[Literal["spearmanr", "pearsonr"]] = None,
                      corr_on: Literal["obs", "var"] = "obs",
                      corr_thresh: Optional[float] = None,
                      ax: Optional[matplotlib.axes.Axes] = None,
                      save: Optional[str] = None,
                      sel_col: str = "grey",
                      om_col: str = "lightgrey"
                      ) -> matplotlib.axes.Axes:
    """Plot the pca variance explained by each component as a barplot.

    The given `ax` is hidden, and the plot(s) are drawn on new axes placed at the same position in the figure.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    method : str, default "pca"
        Method used for calculating variation. Is used to look for the coordinates in adata.uns[<method>].
    n_pcs : int, default 20
        Number of components to plot.
    selected : Optional[List[int]], default None
        Components (1-based) to highlight with `sel_col`; all other bars use `om_col`.
        If None, all bars use `sel_col`.
    show_cumulative : bool, default True
        Whether to show the cumulative variance explained in a second y-axis.
    n_thresh : Optional[int], default None
        Enables a vertical threshold line after component `n_thresh`.
    corr_plot : Optional[str], default None
        Enable correlation plot. Shows highest absolute correlation for each bar.
    corr_on : Literal["obs", "var"], default "obs"
        Calculate correlation on either observations (adata.obs) or variables (adata.var).
    corr_thresh : Optional[float], default None
        Enables a red threshold line in the lower plot.
    ax : Optional[matplotlib.axes.Axes], default None
        Axes object to plot on. If None, a new figure is created.
    save : Optional[str], default None (not saved)
        Filename to save the figure. If None, the figure is not saved.
    sel_col : str, default "grey"
        Bar color of selected bars.
    om_col : str, default "lightgrey"
        Bar color of omitted bars.

    Returns
    -------
    matplotlib.axes.Axes
        The (hidden) Axes object whose area the plot was drawn on.

    Raises
    ------
    ValueError
        If `ax` is not an Axes object.
    KeyError
        If the given method is not found in adata.uns.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.plot_pca_variance(adata, method="pca",
                      n_pcs=20,
                      selected=[2, 3, 4, 5, 7, 8, 9],
                      corr_plot="spearmanr")
    """
    if ax is None:
        _, ax = plt.subplots()
    elif not type(ax).__name__.startswith("Axes"):
        raise ValueError("'ax' parameter needs to be an Axes object. Please check your input.")

    if method not in adata.uns:
        raise KeyError(f"The given method '{method}' is not found in adata.uns. "
                       "Please make sure to run the method before plotting variance.")

    # Get variance (in percent) of the first n_pcs components from the object
    var_explained = adata.uns[method]["variance_ratio"][:n_pcs]
    var_explained = var_explained * 100

    # Cumulative variance
    var_cumulative = np.cumsum(var_explained)

    # Number of bars; smaller than n_pcs if fewer components are stored
    n_bars = len(var_explained)

    if corr_plot:
        # compute correlation coefficients
        corrcoefs, _ = tools.correlation_matrix(adata,
                                                which=corr_on,
                                                basis=method,
                                                n_components=n_pcs,
                                                columns=None,
                                                method=corr_plot)
        abs_corrcoefs = list(corrcoefs.abs().max(axis=0))

    # prepare bar coloring (one color per bar)
    if selected:
        selected_pcs = set(selected)
        palette = [sel_col if pc in selected_pcs else om_col for pc in range(1, n_bars + 1)]
    else:
        # no highlighting
        palette = [sel_col] * n_bars

    # hide the initial ax object
    ax.set_axis_off()
    # get the figure where the plots will be drawn on
    fig = ax.get_figure()

    # create a gridspec (a manual subplot grid) and position it at the location of the ax object
    # get_points() returns [[x0, y0], [x1, y1]], i.e. the lower-left and upper-right corners
    lower_left, upper_right = ax.get_position().get_points()
    gridspec = fig.add_gridspec(ncols=1,
                                nrows=2 if corr_plot else 1,
                                left=lower_left[0],
                                right=upper_right[0],
                                top=upper_right[1],
                                bottom=lower_left[1],
                                hspace=0.1)  # vertical space between the stacked plots

    axs = [fig.add_subplot(gridspec[0, 0])]
    if corr_plot:
        axs.append(fig.add_subplot(gridspec[1, 0]))
        # share x axis between plots
        axs[0].sharex(axs[1])

    # Plot barplot of variance
    x = list(range(1, n_bars + 1))
    sns.barplot(x=x,
                y=var_explained,
                color="grey",
                palette=palette,
                ax=axs[0])
    axs[0].set_ylabel("Variance explained (%)", fontsize=12)

    # Plot cumulative variance (bar positions are 0-based on the categorical axis)
    if show_cumulative:
        ax2 = axs[0].twinx()
        ax2.plot(range(len(var_cumulative)), var_cumulative, color="blue", marker="o", linewidth=1, markersize=3)
        ax2.set_ylabel("Cumulative\nvariance explained (%)", color="blue", fontsize=12)
        ax2.spines['right'].set_color('blue')
        ax2.yaxis.label.set_color('blue')
        ax2.tick_params(axis='y', colors='blue')

    # Add number of selected as line
    if n_thresh:
        if show_cumulative:
            ylim = ax2.get_ylim()
            yrange = ylim[1] - ylim[0]
            ax2.set_ylim(ylim[0], ylim[1] + yrange * 0.1)  # add 10% to make room for legend of n_selected line
        # bar of component n_thresh sits at position n_thresh - 1; draw the line right of it
        axs[0].axvline(n_thresh - 0.5, color="red")

    # Plot absolute correlation bar plot
    if corr_plot:
        if corr_thresh:
            # add threshold line
            axs[1].axhline(corr_thresh, color="red")

        sns.barplot(x=x,
                    y=abs_corrcoefs,
                    color="grey",
                    palette=palette,
                    ax=axs[1])

        # add basis text box
        axs[1].text(
            x=0.95,
            y=0.05,
            s=f"Based on .{corr_on} columns",
            fontsize=12,
            bbox={"boxstyle": "Round", "facecolor": "white", "edgecolor": "black", "alpha": 0.5},
            horizontalalignment="right",
            verticalalignment="bottom",
            transform=axs[1].transAxes
        )

        # Finalize plot
        axs[1].set_xlabel('Principal components', fontsize=12, labelpad=10)
        axs[1].set_ylabel(f"max( |{corr_plot}| )", fontsize=12)
        axs[1].set_ylim([0, 1])
        axs[1].tick_params(axis="x", labelrotation=90, labelsize=7)
        axs[1].set_axisbelow(True)
        axs[1].invert_yaxis()
        axs[1].margins(x=0.01)  # space before first and after last bar

        axs[0].tick_params(bottom=False, labelbottom=False)
        axs[0].margins(x=0.01)  # space before first and after last bar
    else:
        # Finalize plot
        axs[0].set_xlabel('Principal components', fontsize=12, labelpad=10)
        axs[0].tick_params(axis="x", labelrotation=90, labelsize=7)
        axs[0].set_axisbelow(True)
        axs[0].margins(x=0.01)  # space before first and after last bar

    # Save figure
    _save_figure(save)

    return ax
[docs]
@deco.log_anndata
@beartype
def plot_pca_correlation(adata: sc.AnnData,
                         which: Literal["obs", "var"] = "obs",
                         basis: str = "pca",
                         n_components: int = 10,
                         columns: Optional[list[str]] = None,
                         pvalue_threshold: float = 0.01,
                         method: Literal["spearmanr", "pearsonr"] = "spearmanr",
                         plot_values: Literal["corrcoefs", "pvalues"] = "corrcoefs",
                         figsize: Optional[Tuple[int, int]] = None,
                         title: Optional[str] = None,
                         save: Optional[str] = None,
                         **kwargs: Any) -> matplotlib.axes.Axes:
    """
    Plot a heatmap of the correlation between dimensionality reduction coordinates (e.g. umap or pca) and the given columns.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    which : Literal["obs", "var"], default "obs"
        Whether to use the observations ("obs") or variables ("var") for the correlation.
    basis : str, default "pca"
        Dimensionality reduction to calculate correlation with. Must be a key in adata.obsm, or a basis available as "X_<basis>" such as "umap", "tsne" or "pca".
    n_components : int, default 10
        Number of components to use for the correlation.
    columns : Optional[list[str]], default None
        List of columns to use for the correlation. If None, all numeric columns are used.
    pvalue_threshold : float, default 0.01
        Threshold for significance of correlation. If the p-value is below this threshold, a star is added to the heatmap.
    method : Literal["spearmanr", "pearsonr"], default "spearmanr"
        Method to use for correlation. Must be either "pearsonr" or "spearmanr".
    plot_values: Literal["corrcoefs", "pvalues"], default "corrcoefs"
        Values which will be used to plot the heatmap, either "corrcoefs" (correlation coefficients) or "pvalues". P-values will be shown as
        `np.sign(corrcoefs)*np.log10(p-value)`, the logged p-value with the sign of the corresponding correlation coefficient.
        P-values of exactly 0 (e.g. from floating point underflow) are clipped to the smallest positive float before taking the log.
    figsize : Optional[Tuple[int, int]], default None
        Size of the figure in inches. If None, the size is automatically determined.
    title : Optional[str], default None
        Title of the plot. If None, no title is added.
    save : Optional[str], default None
        Filename to save the figure.
    **kwargs : Any
        Additional keyword arguments are passed to :func:`seaborn.heatmap`.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes object containing the heatmap.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.plot_pca_correlation(adata, which="obs")

    .. plot::
        :context: close-figs

        pl.plot_pca_correlation(adata, basis="umap")
    """
    # compute correlation matrix
    corrcoefs, pvalues = tools.correlation_matrix(adata=adata,
                                                  which=which,
                                                  basis=basis,
                                                  n_components=n_components,
                                                  columns=columns,
                                                  method=method)

    # decide which values should be shown (and the colorbar range)
    if plot_values == "corrcoefs":
        table = corrcoefs
        # center of cbar is 0
        vmin, vmax = -1, 1
    else:  # "pvalues"
        # p-values of 0 would give infinite logs; clip them to the smallest positive float
        safe_pvalues = pvalues.clip(lower=np.finfo(float).tiny)
        table = np.sign(corrcoefs) * np.log10(safe_pvalues)
        # infer min and max for cbar from data
        vmin, vmax = None, None

    # prepare annotation shown on the heatmap: rounded values plus stars for significant values
    annot = table.round(2).astype(str)
    stars = pd.DataFrame(np.where(pvalues < pvalue_threshold, "*", ""),
                         index=pvalues.index,
                         columns=pvalues.columns)
    annot = annot + stars

    # Plot heatmap
    figsize = figsize if figsize is not None else (len(table.columns) / 1.5, len(table) / 1.5)
    fig, ax = plt.subplots(figsize=figsize)

    ax = sns.heatmap(table,
                     annot=annot,
                     fmt='',
                     annot_kws={"fontsize": 9},
                     cbar_kws={"label": f"{method} ({plot_values})"},
                     cmap="seismic",
                     vmin=vmin,
                     vmax=vmax,
                     ax=ax,
                     **kwargs)
    ax.set_aspect(0.8)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

    # Set size of cbar to the same height as the heatmap (only if a colorbar axes was created)
    border_axes = [ax]
    fig_axes = fig.get_axes()
    if len(fig_axes) > 1:
        cbar_ax = fig_axes[-1]
        ax_pos = ax.get_position()
        cbar_pos = cbar_ax.get_position()
        cbar_ax.set_position([ax_pos.x1 + 2 * cbar_pos.width, ax_pos.y0,
                              cbar_pos.width, ax_pos.height])
        border_axes.append(cbar_ax)

    # Add black borders to axes
    for ax_obj in border_axes:
        for spine in ax_obj.spines.values():
            spine.set_visible(True)

    # Add title
    if title is not None:
        ax.set_title(str(title))

    # Save figure
    _save_figure(save)

    return ax