Source code for sctoolbox.plotting.velocity

"""Plot velocity related figures e.g. pseudo-time heatmap."""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse import issparse
import scanpy as sc

import sctoolbox.utils as utils
import sctoolbox.utils.decorator as deco
from sctoolbox.plotting.general import _save_figure

# type hint imports
from beartype.typing import Optional, Tuple, Any
from beartype import beartype
import matplotlib


####################################################################################################
#                        Plots correlating gene expression with pseudotime                         #
####################################################################################################

[docs] @deco.log_anndata @beartype def pseudotime_heatmap(adata: sc.AnnData, genes: list[str], sortby: Optional[str] = None, layer: Optional[str] = None, figsize: Optional[Tuple[int | float, int | float]] = None, shrink_cbar: int | float = 0.5, title: Optional[str] = None, save: Optional[str] = None, **kwargs: Any) -> matplotlib.axes.Axes: """ Plot heatmap of genes along pseudotime sorted by 'sortby' column in adata.obs. Parameters ---------- adata : sc.AnnData Anndata object genes : list[str] List of genes for heatmap. sortby : Optional[str], default None Sort genes by condition layer : Optional[str], default None Use different layer of anndata object. figsize : Optional[Tuple[int | float, int | float]], default None Tuple of integers setting the heatmap figsize. shrink_cbar : int | float, default 0.5 Shrink color bar by set ratio. title : Optional[str], default None Set title for plot. save : Optional[str], default None Path and name of file to be saved. **kwargs : Any Additional arguments passed to seaborn.heatmap. Returns ------- ax : matplotlib.axes.Axes Axes object containing the plot. """ adata_sub = adata[:, genes].copy() # Sort adata if sortby is not None: obs_sorted = adata_sub.obs.sort_values(sortby) adata_sub = adata_sub[obs_sorted.index, :] # Collect matrix if layer is not None: mat = adata_sub.layers[layer] else: mat = adata_sub.X mat = mat.todense() if issparse(mat) else mat mat = mat.T # pseudotime on x-axis # Convert to pandas dataframe mat = pd.DataFrame(mat) mat.index = genes # z-score normalize per row mat = utils.table_zscore(mat) # Plot heatmap n_genes = len(mat) n_cells = mat.shape[1] nrows = 1 if figsize is None: figsize = (6, n_genes / 5) # Setup figure fig, axarr = plt.subplots(nrows, 1, sharex=True, figsize=figsize) # , height_ratios=(1, len(mat))) axarr = [axarr] if type(axarr).__name__.startswith("Axes") else axarr i = 0 parameters = {"center": 0, "cmap": "bwr"} parameters.update(kwargs) ax = sns.heatmap(mat, ax=axarr[i], yticklabels=True, # make sure all labels are shown cbar_kws={"label": "Expr.z-score", "shrink": shrink_cbar, "anchor": (0, 0), "aspect": 20 * shrink_cbar * 2}, # width of cbar after shrink by adjusting aspect **parameters) ax.set_xticks([]) # remove x-ticks if title is not None: ax.set_title(title) # Draw pseudotime arrow below heatmap ax.annotate('', xy=(0, n_genes + 1), xycoords=ax.transData, xytext=(n_cells, n_genes + 1), arrowprops=dict(arrowstyle="<-", color='black')) ax.text(n_cells / 2, n_genes + 1.2, f"Pseudotime (n={n_cells:,} cells)", transform=ax.transData, ha="center", va="top") # Save figure _save_figure(save) return ax