Source code for sctoolbox.utils.adata

"""anndata.AnnData related functions."""

import numpy as np
import scanpy as sc
from collections.abc import Sequence  # check if object is iterable
from collections import OrderedDict
import scipy
import matplotlib.pyplot as plt
from scipy.sparse import issparse

from beartype.typing import Optional, Any
from beartype import beartype

import sctoolbox.utils.decorator as deco
from sctoolbox._settings import settings
logger = settings.logger


[docs] @beartype def get_adata_subsets(adata: sc.AnnData, groupby: str) -> dict[str, sc.AnnData]: """ Split an anndata object into a dict of sub-anndata objects based on a grouping column. Parameters ---------- adata : sc.AnnData Anndata object to split. groupby : str Column name in adata.obs to split by. Returns ------- dict[str, sc.AnnData] Dictionary of anndata objects in the format {<group1>: anndata, <group2>: anndata, (...)}. Raises ------ ValueError If groupby is not found in `adata.obs.columns`. """ if groupby not in adata.obs.columns: raise ValueError(f"Column '{groupby}' not found in adata.obs") group_names = adata.obs[groupby].astype("category").cat.categories.tolist() adata_subsets = {name: adata[adata.obs[groupby] == name] for name in group_names} logger.debug("Split adata into {} subsets based on column '{}'".format(len(adata_subsets), groupby)) return adata_subsets
[docs] @deco.log_anndata @beartype def add_expr_to_obs(adata: sc.AnnData, gene: str) -> None: """ Add expression of a gene from adata.X to adata.obs as a new column. Parameters ---------- adata : sc.AnnData Anndata object to add expression to. gene : str Gene name to add expression of. Raises ------ Exception If the gene is not found in the adata object. """ boolean = adata.var.index == gene if sum(boolean) == 0: raise Exception(f"Gene {gene} not found in adata.var.index") else: idx = np.argwhere(boolean)[0][0] adata.obs[gene] = adata.X[:, idx].todense().A1
[docs] @deco.log_anndata @beartype def shuffle_cells(adata: sc.AnnData, seed: int = 42) -> sc.AnnData: """ Shuffle cells in an adata object to improve plotting. Otherwise, cells might be hidden due plotting samples in order e.g. sample1, sample2, etc. Parameters ---------- adata : sc.AnnData Anndata object to shuffle cells in. seed : int, default 42 Seed for random number generator. Returns ------- sc.AnnData Anndata object with shuffled cells. """ import random state = random.getstate() random.seed(seed) shuffled_barcodes = random.sample(adata.obs.index.tolist(), len(adata)) adata = adata[shuffled_barcodes] random.setstate(state) # reset random state return adata
[docs] @beartype def get_minimal_adata(adata: sc.AnnData) -> sc.AnnData: """ Return a minimal copy of an anndata object e.g. for estimating UMAP in parallel. Parameters ---------- adata : sc.AnnData Annotated data matrix. Returns ------- sc.AnnData Minimal copy of anndata object. """ adata_minimal = adata.copy() adata_minimal.X = None adata_minimal.layers = None adata_minimal.raw = None return adata_minimal
[docs] @beartype def load_h5ad(path: str) -> sc.AnnData: """ Load an anndata object from .h5ad file. Parameters ---------- path : str Name of the file to load the anndata object. NOTE: Uses the internal 'sctoolbox.settings.adata_input_dir' + 'sctoolbox.settings.adata_input_prefix' as prefix. Returns ------- sc.AnnData Loaded anndata object. """ adata_input = settings.full_adata_input_prefix + path adata = sc.read_h5ad(filename=adata_input) logger.info(f"The adata object was loaded from: {adata_input}") return adata
[docs] @deco.log_anndata @beartype def save_h5ad(adata: sc.AnnData, path: str) -> None: """ Save an anndata object to an .h5ad file. Parameters ---------- adata : sc.AnnData Anndata object to save. path : str Name of the file to save the anndata object. NOTE: Uses the internal 'sctoolbox.settings.adata_output_dir' + 'sctoolbox.settings.adata_output_prefix' as prefix. """ # Save adata adata_output = settings.full_adata_output_prefix + path adata.write(filename=adata_output) logger.info(f"The adata object was saved to: {adata_output}")
[docs] @beartype def add_uns_info(adata: sc.AnnData, key: str | list[str], value: Any, how: str = "overwrite") -> None: """ Add information to adata.uns['sctoolbox']. This is used for logging the parameters and options of different steps in the analysis. Parameters ---------- adata : sc.AnnData An AnnData object. key : str | list[str] The key to add to adata.uns['sctoolbox']. If the key is a list, it represents a path within a nested dictionary. value : Any The value to add to adata.uns['sctoolbox']. how : str, default "overwrite" When set to "overwrite" provided key will be overwriten. If "append" will add element to existing list or dict. Raises ------ ValueError If value can not be appended. """ if "sctoolbox" not in adata.uns: adata.uns["sctoolbox"] = {} if isinstance(key, str): key = [key] d = adata.uns["sctoolbox"] for k in key[:-1]: # iterate over all keys except the last one if k not in d: d[k] = d.get(k, {}) d = d[k] # Add value to last key last_key = key[-1] if how == "overwrite": d[last_key] = value # last key contains value elif how == "append": if key[-1] not in d: d[last_key] = value # initialize with a value if key does not exist else: # append to existing key current_value = d[last_key] if isinstance(value, dict) and not isinstance(current_value, dict): nested = "adata.uns['sctoolbox'][" + "][".join(key) + "]" raise ValueError(f"Cannot append {value} to {nested} because it is not a dict.") elif type(current_value).__name__ == "ndarray": # convert numpy array to list in order to use "append"/extend" d[last_key] = list(current_value) else: d[last_key] = [current_value] # Append/extend/update value if isinstance(value, list): d[last_key].extend(value) elif isinstance(value, dict): d[last_key].update(value) # update dict else: d[last_key].append(value) # value is a single value # If list; remove duplicates and keep the last occurrence if isinstance(d[last_key], Sequence): d[last_key] = list(reversed(OrderedDict.fromkeys(reversed(d[last_key])))) # reverse list to keep last occurrence instead of first
[docs] @beartype def get_cell_values(adata: sc.AnnData, element: str) -> np.ndarray: """Get the values of a given element in adata.obs or adata.var per cell in adata. Can for example be used to extract gene expression values. Parameters ---------- adata : anndata.AnnData Anndata object. element : str The element to extract from adata.obs or adata.var, e.g. a column in adata.obs or an index in adata.var. Returns ------- np.ndarray Array of values per cell in adata. Raises ------ ValueError If element is not found in adata.obs or adata.var. """ if element in adata.obs: values = np.array(adata.obs[element].values) elif element in adata.var.index: idx = list(adata.var.index).index(element) values = adata.X[:, idx] values = values.todense().A1 if issparse(values) else values else: raise ValueError(f"Element '{element}' not found in adata.obs or adata.var.") return values
[docs] @beartype def prepare_for_cellxgene(adata: sc.AnnData, keep_obs: Optional[list[str]] = None, keep_var: Optional[list[str]] = None, rename_obs: Optional[dict[str, str]] = None, rename_var: Optional[dict[str, str]] = None, embedding_names: Optional[list[str]] = ["pca", "umap", "tsne"], cmap: Optional[str] = None, palette: Optional[str | Sequence[str]] = None, inplace: bool = False) -> Optional[sc.AnnData]: """ Prepare the given adata for cellxgene deployment. Parameters ---------- adata : sc.Anndata Anndata object. keep_obs : Optional[list[str]], default None adata.obs columns that should be kept. None to keep all. keep_var : Optional[list[str]], default None adata.var columns that should be kept. None to keep all. rename_obs : Optional[dict[str, str]], default None Dictionary of .obs columns to rename. Key is the old name, value the new one. rename_var : Optional[dict[str, str]], default None Dictionary of .var columns to rename. Key is the old name, value the new one. embedding_names : Optional[list[str]], default ["pca", "umap", "tsne"] List of embeddings to check for. Will raise an error if none of the embeddings are found. Set None to disable check. Embeddings are stored in `adata.obsm`. cmap : Optional[str], default None Color map to use for continous variables. Use this replacement color map for broken color maps. If None will use scanpy default, which uses `mpl.rcParams["image.cmap"]`. See `sc.pl.embedding`. palette : Optional[str | Sequence[str]], default None Color map to use for categorical annotation groups. Use this replacement color map for broken color maps. If None will use scanpy default, which uses `mpl.rcParams["axes.prop_cycle"]`. See `sc.pl.embedding`. inplace : bool, default False Raises ------ ValueError If not at least one of the named embeddings are found in the adata. Returns ------- Optional[sc.AnnData] Returns the deployment ready Anndata object. """ def clean_section(obj, axis="obs", keep=None, rename=None) -> None: """Clean either obs or var section of given adata object.""" if axis == "obs": sec_table = obj.obs elif axis == "var": sec_table = obj.var # drop columns if keep is not None: drop = set(sec_table.columns) - set(keep) sec_table.drop(columns=drop, inplace=True) # drop matching color maps for col in drop: if f"{col}_colors" in obj.uns.keys(): obj.uns.pop(f"{col}_colors") # rename columns if rename: sec_table.rename(columns=rename, inplace=True) # rename color maps for old, new in rename.items(): if f"{old}_colors" in obj.uns.keys(): obj.uns[f"{new}_colors"] = obj.uns.pop(f"{old}_colors") # convert Int32 to float64 columns for c in sec_table: if sec_table[c].dtype == 'Int32': sec_table[c] = sec_table[c].astype('float64') sec_table.index.names = ['index'] out = adata if inplace else adata.copy() # TODO remove more adata internals not needed for cellxgene # ----- .obsm ----- if embedding_names: if not any(f"X_{e}" == k for e in embedding_names for k in out.obsm.keys()): raise ValueError(f"Unable to find any of the embeddings {embedding_names}. At least one is needed for cellxgene.") # ----- .obs ----- clean_section(out, axis="obs", keep=keep_obs, rename=rename_obs) # ----- .var ----- clean_section(out, axis="var", keep=keep_var, rename=rename_var) # ----- .X ----- # convert .X to sparse matrix if needed if not scipy.sparse.isspmatrix(out.X): out.X = scipy.sparse.csc_matrix(out.X) out.X = out.X.astype("float32") # ----- .uns ----- for key in list(out.uns): # avoid RuntimeError by forcing a copy of dict keys. if key.endswith('colors'): obs_key = key.split("_colors")[0] # delete colors if they don't match a .obs column. if obs_key not in out.obs.columns: out.uns.pop(key) logger.warning(f"Deleted .uns[{key}] since it did not match a .obs column.") continue # fix colors not in 6-digit hex format # https://github.com/chanzuckerberg/cellxgene/issues/2598 out.uns[key] = np.array([(c if len(c) <= 7 else c[:-2]) for c in out.uns[key]]) # fix number of colors < number of categories if len(out.uns[key]) != len(set(out.obs[obs_key])): logger.warning(f"Coloring for adata.obs['{obs_key}'] broken. Reverting to {cmap if cmap else 'scanpy default'} color map.") # scanpy replaces broken colormap before plotting basis = list(out.obsm.keys())[0] sc.pl.embedding(adata=out, basis=basis, color=obs_key, palette=palette, color_map=cmap, show=False) plt.close() # prevent that plot is shown if not inplace: return out