"""Functions for plotting clustering results e.g. UMAPs colored by clusters."""
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import warnings
from beartype import beartype
from beartype.typing import Literal, Tuple, Optional, Any
import sctoolbox.utils as utils
from sctoolbox.plotting.general import _save_figure
import sctoolbox.utils.decorator as deco
from sctoolbox._settings import settings
# Logger obtained from the sctoolbox settings object; used for progress messages in this module.
logger = settings.logger
[docs]
@deco.log_anndata
@beartype
def search_clustering_parameters(adata: sc.AnnData,
                                 method: Literal["leiden", "louvain"] = "leiden",
                                 resolution_range: Tuple[float | int, float | int, float | int] = (0.1, 1, 0.1),
                                 embedding: str = "X_umap",
                                 ncols: int = 3,
                                 verbose: bool = True,
                                 save: Optional[str] = None,
                                 **kwargs: Any) -> np.ndarray:
    """
    Plot a grid of different resolution parameters for clustering.

    For every tested resolution the clustering is computed and written to `adata.obs`
    under the key "<method>_<resolution>" (e.g. "leiden_0.1"), then plotted on the embedding.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix object.
    method : str, default: "leiden"
        Clustering method to use. Can be one of 'leiden' or 'louvain'.
    resolution_range : Tuple[float | int, float | int, float | int], default: (0.1, 1, 0.1)
        Range of 'resolution' parameter values to test. Must be a tuple in the form (min, max, step).
        The values are generated with `np.arange`, i.e. 'max' itself is normally not included.
    embedding : str, default: "X_umap"
        Embedding method to use. Must be a key in adata.obsm. If not, will try to use f"X_{embedding}".
    ncols : int, default: 3
        Number of columns in the grid. Capped at the number of tested resolutions.
    verbose : bool, default: True
        Print progress to console.
    save : Optional[str], default None
        Path to save figure.
    **kwargs : Any
        Keyword arguments to be passed to sc.pl.embedding.

    Returns
    -------
    axarr : np.ndarray
        Two-dimensional array (rows x columns) of axes objects containing the plot(s).

    Raises
    ------
    ValueError
        If step is not positive, if step is larger than max - min, or if ncols is smaller than 1.
    KeyError
        If embedding is not found in adata.obsm.

    Examples
    --------
    .. plot::
        :context: close-figs

        pl.search_clustering_parameters(adata, method='louvain', resolution_range=(0.1, 2, 0.2), embedding='X_umap', ncols=3, verbose=True, save=None)
    """

    # Check validity of parameters
    res_min, res_max, res_step = resolution_range
    if res_step <= 0:
        # A zero/negative step would otherwise surface as an obscure error from np.arange
        # or as a division by zero when the grid layout is computed.
        raise ValueError("'step' of resolution_range must be larger than 0. Please adjust.")
    if res_step > res_max - res_min:
        raise ValueError("'step' of resolution_range is larger than 'max' - 'min'. Please adjust.")
    if ncols < 1:
        raise ValueError("'ncols' must be at least 1. Please adjust.")

    # Check that coordinates for embedding is available in .obsm
    if embedding not in adata.obsm:
        embedding = f"X_{embedding}"
        if embedding not in adata.obsm:
            raise KeyError(f"The embedding '{embedding}' was not found in adata.obsm. Please adjust this parameter.")

    # The allowed method names equal the names of the scanpy clustering functions
    cl_function = getattr(sc.tl, method)

    # Setup parameters to loop over (rounded to avoid floating point artifacts in names/titles)
    resolutions = np.around(np.arange(res_min, res_max, res_step), 2)
    n_resolutions = len(resolutions)

    # Figure with given number of cols
    ncols = min(ncols, n_resolutions)  # number of resolutions caps number of columns
    nrows = int(np.ceil(n_resolutions / ncols))
    # squeeze=False always yields a 2D array of axes, also for single rows/columns
    _, axarr = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axarr.flatten()

    for i, res in enumerate(resolutions):

        if verbose:
            logger.info(f"Plotting umap for resolution={res} ({i+1} / {len(resolutions)})")

        # Run clustering
        key_added = method + "_" + str(round(res, 2))
        cl_function(adata, resolution=res, key_added=key_added)
        adata.obs[key_added] = utils.rename_categories(adata.obs[key_added])  # rename to start at 1
        n_clusters = adata.obs[key_added].nunique()

        # Plot embedding
        title = f"Resolution: {res} (clusters: {n_clusters})\ncolumn name: {key_added}"
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, message="No data for colormapping provided via 'c'*")
            sc.pl.embedding(adata, embedding, color=key_added, ax=axes[i], legend_loc="on data", title=title, show=False, **kwargs)

    # Hide plots not filled in
    for ax in axes[n_resolutions:]:
        ax.axis('off')

    plt.tight_layout()
    _save_figure(save)

    return axarr
[docs]
@deco.log_anndata
@beartype
def marker_gene_clustering(adata: sc.AnnData,
                           groupby: str,
                           marker_genes_dict: dict[str, list[str]],
                           show_umap: bool = True,
                           save: Optional[str] = None,
                           figsize: Optional[Tuple[float | int, float | int]] = None,
                           **kwargs: Any) -> list:
    """
    Plot an overview of marker genes and clustering.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix.
    groupby : str
        Key in `adata.obs` for which to plot the clustering.
    marker_genes_dict : dict[str, list[str]]
        Dictionary of marker genes to plot. Keys are the names of the groups and values are lists of marker genes.
    show_umap : bool, default: True
        Whether to show a UMAP plot on the left.
    save : Optional[str], default: None
        If given, save the figure to this path.
    figsize : Optional[Tuple[float | int, float | int]], default: None
        Size of the figure. If `None`, (12, 6) is used with the UMAP and (6, 6) without it.
    **kwargs : Any
        Keyword arguments to be passed to sc.pl.dotplot. `dendrogram` defaults to True but can be overwritten here.

    Returns
    -------
    axarr : list
        List of axes objects containing the plot(s).

    Examples
    --------
    .. plot::
        :context: close-figs

        marker_genes_dict = {"S": ["PCNA"], "G2M": ["HMGB2"]}
        pl.marker_gene_clustering(adata, "phase", marker_genes_dict, show_umap=True, save=None, figsize=None)
    """

    if show_umap:
        figsize = (12, 6) if figsize is None else figsize
        fig, axarr = plt.subplots(1, 2, figsize=figsize, gridspec_kw={'width_ratios': [1, 2]})
        umap_ax, dotplot_ax = axarr

        # Plot UMAP colored by groupby on the left
        sc.pl.umap(adata, color=groupby, ax=umap_ax, legend_loc="on data", show=False)
        umap_ax.set_aspect('equal')
    else:
        figsize = (6, 6) if figsize is None else figsize
        fig, dotplot_ax = plt.subplots(1, 1, figsize=figsize)
        axarr = [dotplot_ax]  # Make sure axarr can be converted to a list of axes

    # Make sure all genes are in the data
    marker_genes_dict = utils.check_marker_lists(adata, marker_genes_dict)

    # Plot marker gene expression on the right.
    # 'dendrogram' is only a default, so that passing it via kwargs does not collide with a fixed keyword.
    dotplot_kwargs = {"dendrogram": True, **kwargs}
    ax = sc.pl.dotplot(adata, marker_genes_dict, groupby=groupby, show=False, ax=dotplot_ax, **dotplot_kwargs)

    main_ax = ax["mainplot_ax"]
    main_ax.set_ylabel(groupby)
    main_ax.set_xticklabels(main_ax.get_xticklabels(), ha="right", rotation=45)

    # Rotate the gene group labels via the public matplotlib API instead of private attributes
    for text in ax["gene_group_ax"].texts:
        text.set_rotation(45)
        text.set_horizontalalignment("left")

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.2)  # adjust the figure created here, independent of the current pyplot figure

    # Save figure
    _save_figure(save)

    return list(axarr)