Source code for sctoolbox.tools.norm_correct

"""Normalization and correction tools."""
import numpy as np
from scipy import sparse
import io
from contextlib import redirect_stderr
import copy
import multiprocessing as mp
import scanpy as sc
import scanpy.external as sce

from beartype.typing import Optional, Any, Union, Literal, Callable
from beartype import beartype

import sctoolbox.utils as utils
from sctoolbox.tools.dim_reduction import lsi
import sctoolbox.utils.decorator as deco
from sctoolbox._settings import settings
logger = settings.logger

batch_methods = Literal["bbknn",
                        "combat",
                        "mnn",
                        "harmony",
                        "scanorama"]

#####################################################################
# --------------------- Normalization methods --------------------- #
#####################################################################


def atac_norm(*args: Any, **kwargs: Any):
    """Normalize ATAC data - deprecated functionality. Use normalize_adata instead."""
    logger.warning("The function 'atac_norm' is deprecated. Use 'normalize_adata' instead.")
    return normalize_adata(*args, **kwargs)


@deco.log_anndata
@beartype
def normalize_adata(adata: sc.AnnData,
                    method: str | list[str],
                    exclude_highly_expressed: bool = True,
                    use_highly_variable: bool = False,
                    target_sum: Optional[int] = None) -> dict[str, sc.AnnData]:
    """
    Normalize the count matrix and calculate dimension reduction using different methods.

    Parameters
    ----------
    adata : sc.AnnData
        Annotated data matrix.
    method : str | list[str]
        Normalization method. Either 'total' and/or 'tfidf'.

        - 'total': Performs normalization for total counts, log1p and PCA.
        - 'tfidf': Performs TFIDF normalization and LSI (corresponds to PCA). This method is often used for scATAC-seq data.
    exclude_highly_expressed : bool, default True
        Parameter for sc.pp.normalize_total. Decision to exclude highly expressed genes (HEG) from total normalization.
    use_highly_variable : bool, default False
        Parameter for sc.pp.pca and lsi. Decision to use highly variable genes for PCA/LSI.
    target_sum : Optional[int], default None
        Parameter for sc.pp.normalize_total. Decide the target sum of each cell after normalization.

    Returns
    -------
    dict[str, sc.AnnData]
        Dictionary containing the method name as key and an anndata object as value.
        Each anndata is the annotated data matrix with normalized count matrix and PCA/LSI calculated.

    Raises
    ------
    ValueError
        If method is not valid. Needs to be either 'total' or 'tfidf'.
    """

    if isinstance(method, str):
        method = [method]

    adatas = {}
    for method_str in method:  # method is a list

        adata = adata.copy()  # make sure the original data is not modified

        if method_str == "total":  # perform total normalization and pca
            logger.info('Performing total normalization and PCA...')
            sc.pp.normalize_total(adata, exclude_highly_expressed=exclude_highly_expressed, target_sum=target_sum)
            sc.pp.log1p(adata)
            sc.pp.pca(adata, use_highly_variable=use_highly_variable)

        elif method_str == "tfidf":
            logger.info('Performing TFIDF and LSI...')
            tfidf(adata)
            lsi(adata, use_highly_variable=use_highly_variable)  # corresponds to PCA

        else:
            raise ValueError(f"Method '{method_str}' is invalid - must be either 'total' or 'tfidf'.")

        adatas[method_str] = adata

    return adatas
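

# Example usage (illustrative sketch, not executed as part of the module;
# "adata" is assumed to be an AnnData object with a raw count matrix in .X):
#
#   adatas = normalize_adata(adata, method=["total", "tfidf"])
#   adata_total = adatas["total"]   # total-count normalized, log1p, PCA
#   adata_tfidf = adatas["tfidf"]   # TF-IDF normalized, LSI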


@beartype
def tfidf(data: sc.AnnData,
          log_tf: bool = True,
          log_idf: bool = True,
          log_tfidf: bool = False,
          scale_factor: int = int(1e4)) -> None:
    """
    Transform peak counts with TF-IDF (Term Frequency - Inverse Document Frequency).

    TF: peak counts are normalised by the total number of counts per cell.
    DF: total number of counts for each peak.
    IDF: number of cells divided by DF.

    By default, log(TF) * log(IDF) is returned.

    Parameters
    ----------
    data : sc.AnnData
        AnnData object with peak counts.
    log_tf : bool, default True
        Log-transform TF term if True.
    log_idf : bool, default True
        Log-transform IDF term if True.
    log_tfidf : bool, default False
        Log-transform TF*IDF term if True. Can only be used when log_tf and log_idf are False.
    scale_factor : int, default 1e4
        Scale factor to multiply the TF-IDF matrix by.

    Notes
    -----
    Function is from the muon package.

    Raises
    ------
    AttributeError
        log(TF*IDF) requires log(TF) and log(IDF) to be False.
    """

    adata = data

    if log_tfidf and (log_tf or log_idf):
        raise AttributeError("When returning log(TF*IDF), applying neither log(TF) nor log(IDF) is possible.")

    if sparse.issparse(adata.X):
        n_peaks = np.asarray(adata.X.sum(axis=1)).reshape(-1)
        n_peaks = sparse.dia_matrix((1.0 / n_peaks, 0), shape=(n_peaks.size, n_peaks.size))
        # This prevents making TF dense
        tf = np.dot(n_peaks, adata.X)
    else:
        n_peaks = np.asarray(adata.X.sum(axis=1)).reshape(-1, 1)
        tf = adata.X / n_peaks

    if scale_factor is not None and scale_factor != 0 and scale_factor != 1:
        tf = tf * scale_factor
    if log_tf:
        tf = np.log1p(tf)

    idf = np.asarray(adata.shape[0] / adata.X.sum(axis=0)).reshape(-1)
    if log_idf:
        idf = np.log1p(idf)

    if sparse.issparse(tf):
        idf = sparse.dia_matrix((idf, 0), shape=(idf.size, idf.size))
        tf_idf = np.dot(tf, idf)
    else:
        tf_idf = np.dot(sparse.csr_matrix(tf), sparse.csr_matrix(np.diag(idf)))

    if log_tfidf:
        tf_idf = np.log1p(tf_idf)

    adata.X = np.nan_to_num(tf_idf, 0)
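

# Example usage (illustrative sketch; "atac_adata" is an assumed AnnData with a
# cell x peak count matrix). With the default arguments, .X is replaced in place
# by log1p(TF * scale_factor) * log1p(IDF):
#
#   tfidf(atac_adata)   # modifies atac_adata.X in place, returns None
#   lsi(atac_adata)     # LSI is typically run afterwards (see normalize_adata)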


@beartype
def tfidf_normalization(matrix: sparse.spmatrix,
                        tf_type: Literal["raw", "term_frequency", "log"] = "term_frequency",
                        idf_type: Literal["unary", "inverse_freq", "inverse_freq_smooth"] = "inverse_freq") -> sparse.csr_matrix:
    """
    Perform TF-IDF normalization on a sparse matrix.

    The different variants of the term frequency and inverse document frequency are obtained from https://en.wikipedia.org/wiki/Tf-idf.

    Parameters
    ----------
    matrix : sparse.spmatrix
        The matrix to be normalized.
    tf_type : Literal["raw", "term_frequency", "log"], default "term_frequency"
        The type of term frequency to use. Can be either "raw", "term_frequency" or "log".
    idf_type : Literal["unary", "inverse_freq", "inverse_freq_smooth"], default "inverse_freq"
        The type of inverse document frequency to use. Can be either "unary", "inverse_freq" or "inverse_freq_smooth".

    Returns
    -------
    sparse.csr_matrix
        TF-IDF normalized sparse matrix.

    Notes
    -----
    This function requires a lot of memory. Another option is to use ac.pp.tfidf from the muon package.
    """

    # t - term (peak)
    # d - document (cell)
    # N - count of corpus (total set of cells)

    # Normalize matrix to number of found peaks
    dense = matrix.todense()
    peaks_per_cell = dense.sum(axis=1)  # i.e. the length of the document (number of words)

    # Decide which term frequency to use:
    if tf_type == "raw":
        tf = dense
    elif tf_type == "term_frequency":
        tf = dense / peaks_per_cell  # counts normalized to peaks (words) per cell (document)
    elif tf_type == "log":
        tf = np.log1p(dense)  # for binary documents, this scales with "raw"

    # Decide which inverse document frequency to use
    N = dense.shape[0]  # number of cells (number of documents)
    df = dense.sum(axis=0)  # number of cells carrying each peak (number of documents containing each word)

    if idf_type == "unary":
        idf = np.ones(dense.shape[1])  # shape is number of peaks
    elif idf_type == "inverse_freq":
        idf = np.log(N / df)  # each cell has at least one peak (each document has one word), so df is always > 0
    elif idf_type == "inverse_freq_smooth":
        idf = np.log(N / (df + 1)) + 1

    # Obtain TF-IDF
    tf_idf = np.array(tf) * np.array(idf).squeeze()
    tf_idf = sparse.csr_matrix(tf_idf)

    return tf_idf
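

# Example usage (illustrative sketch; the toy matrix below is an assumption for
# demonstration, not test data from this package):
#
#   mat = sparse.csr_matrix(np.array([[1, 0, 2],
#                                     [0, 3, 1]]))
#   norm = tfidf_normalization(mat, tf_type="term_frequency", idf_type="inverse_freq")
#   # norm is a csr_matrix of the same shape; note that .todense() is called
#   # internally, so very large matrices may exceed memory.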


###################################################################################
# --------------------------- Batch correction methods -------------------------- #
###################################################################################


@beartype
def wrap_corrections(adata: sc.AnnData,
                     batch_key: str,
                     methods: Union[batch_methods, list[batch_methods], Callable] = ["bbknn", "mnn"],
                     method_kwargs: dict = {}) -> dict[str, sc.AnnData]:
    """
    Calculate multiple batch corrections for adata using the 'batch_correction' function.

    Parameters
    ----------
    adata : sc.AnnData
        An annotated data matrix object to apply corrections to.
    batch_key : str
        The column in adata.obs containing batch information.
    methods : batch_methods | list[batch_methods] | Callable, default ["bbknn", "mnn"]
        The method(s) to use for batch correction. Options are:

        - bbknn
        - mnn
        - harmony
        - scanorama
        - combat

        Or provide a custom batch correction function. See `batch_correction(method)` for more information.
    method_kwargs : dict, default {}
        Dict with methods as keys. Values are dicts of additional parameters forwarded to the method. See batch_correction(**kwargs).

    Returns
    -------
    dict[str, sc.AnnData]
        Dictionary of batch corrected anndata objects, where the key is the correction method and the value is the corrected anndata.

    Raises
    ------
    ValueError
        If not all methods in `methods` are valid.
    """

    # Ensure that methods can be looped over
    if isinstance(methods, str):
        methods = [methods]

    # Check method_kwargs keys
    unknown_keys = set(method_kwargs.keys()) - set(methods)
    if unknown_keys:
        raise ValueError(f"Unknown methods in `method_kwargs` keys: {unknown_keys}")

    # Check the existence of packages before running batch corrections
    required_packages = {"harmony": "harmonypy", "bbknn": "bbknn", "scanorama": "scanorama"}
    for method in methods:
        if method in required_packages:  # not all packages need external tools
            f = io.StringIO()
            with redirect_stderr(f):  # make the output of check_module silent; mnnpy prints ugly warnings
                utils.check_module(required_packages[method])

    # Collect batch correction per method
    anndata_dict = {'uncorrected': adata}
    for method in methods:
        anndata_dict[method] = batch_correction(adata, batch_key, method, **method_kwargs.setdefault(method, {}))  # batch_correction returns the corrected adata

    logger.info("Finished batch correction(s)!")

    return anndata_dict
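

# Example usage (illustrative sketch; "adata", the batch column name "batch" and
# the harmony keyword shown are assumptions, the latter about harmonypy's interface):
#
#   corrections = wrap_corrections(adata, batch_key="batch",
#                                  methods=["bbknn", "harmony"],
#                                  method_kwargs={"harmony": {"max_iter_harmony": 20}})
#   corrections.keys()  # dict_keys(['uncorrected', 'bbknn', 'harmony'])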


@deco.log_anndata
@beartype
def batch_correction(adata: sc.AnnData,
                     batch_key: str,
                     method: Union[batch_methods, list[batch_methods], Callable] = ["bbknn", "mnn"],
                     highly_variable: bool = True,
                     **kwargs: Any) -> sc.AnnData:
    """
    Perform batch correction on the adata object using the 'method' given.

    Parameters
    ----------
    adata : sc.AnnData
        An annotated data matrix object to apply corrections to.
    batch_key : str
        The column in adata.obs containing batch information.
    method : str or function
        Either one of the predefined methods or a custom function for batch correction.
        Note: The custom function is expected to accept an anndata object as the first parameter and return the batch corrected anndata.

        Available methods:

        - bbknn
        - mnn
        - harmony
        - scanorama
        - combat
    highly_variable : bool, default True
        Only for method 'mnn'. If True, only the highly variable genes (column 'highly_variable' in .var) will be used for batch correction.
    **kwargs : Any
        Additional arguments will be forwarded to the method function.

    Returns
    -------
    sc.AnnData
        A copy of the anndata with applied batch correction.

    Raises
    ------
    ValueError
        1. If the batch_key column is not in adata.obs.
        2. If the batch correction method is invalid.
    KeyError
        If PCA has not been calculated before running bbknn.
    """

    if not callable(method):
        method = method.lower()

    logger.info(f"Running batch correction with '{method}'...")

    # Check that batch_key is in adata object
    if batch_key not in adata.obs.columns:
        raise ValueError(f"The given batch_key '{batch_key}' is not in adata.obs.columns")

    # Run batch correction depending on method
    if method == "bbknn":
        import bbknn  # sc.external.pp.bbknn() is broken due to n_trees / annoy_n_trees change

        # Get number of pcs in adata, as bbknn hardcodes n_pcs=50
        try:
            n_pcs = adata.obsm["X_pca"].shape[1]
        except KeyError:
            raise KeyError("PCA has not been calculated. Please run sc.pp.pca() before running bbknn.")

        # Run bbknn
        adata = bbknn.bbknn(adata, batch_key=batch_key, n_pcs=n_pcs, copy=True, **kwargs)  # bbknn is an alternative to neighbors

    elif method == "mnn":

        var_table = adata.var  # var_table before batch correction

        # Split adata on batch_key
        batch_categories = list(set(adata.obs[batch_key]))
        adatas = [adata[adata.obs[batch_key] == category] for category in batch_categories]

        # Set highly variable genes as var_subset if chosen (and available)
        var_subset = None
        if highly_variable:
            if "highly_variable" in adata.var.columns:
                var_subset = adata.var[adata.var.highly_variable].index

        # Give individual adatas to mnn_correct
        corrected_adatas, _, _ = sce.pp.mnn_correct(adatas, batch_key=batch_key, var_subset=var_subset,
                                                    batch_categories=batch_categories, do_concatenate=False, **kwargs)

        # Join corrected adatas
        corrected_adatas = corrected_adatas[0]  # the output is a tuple containing a list ([adata1, adata2, (...)], )
        adata = sc.concat(corrected_adatas, join="outer", uns_merge="first")
        adata.var = var_table  # add var table back into corrected adata

        sc.pp.scale(adata)  # from the mnnpy github example
        sc.tl.pca(adata)  # rerun pca
        sc.pp.neighbors(adata)

    elif method == "harmony":
        adata = adata.copy()  # there is no copy option for harmony
        adata.obs[batch_key] = adata.obs[batch_key].astype("str")  # harmony expects the batch key as string

        sce.pp.harmony_integrate(adata, key=batch_key, **kwargs)
        adata.obsm["X_pca"] = adata.obsm["X_pca_harmony"]
        sc.pp.neighbors(adata)

    elif method == "scanorama":
        adata = adata.copy()  # there is no copy option for scanorama

        # scanorama expects the batch key in a sorted format,
        # therefore adata.obs should be sorted based on the batch column before this method.
        original_order = adata.obs.index
        adata = adata[adata.obs[batch_key].argsort()]  # sort the whole adata to make sure obs is in the same order as the matrix

        sce.pp.scanorama_integrate(adata, key=batch_key, **kwargs)
        adata.obsm["X_pca"] = adata.obsm["X_scanorama"]
        sc.pp.neighbors(adata)

        # Sort the adata back to the original order
        adata = adata[original_order]

    elif method == "combat":
        corrected_mat = sc.pp.combat(adata, key=batch_key, inplace=False, **kwargs)

        adata = adata.copy()  # make sure adata is not modified
        adata.X = sparse.csr_matrix(corrected_mat)

        sc.pp.pca(adata)
        sc.pp.neighbors(adata)

    elif callable(method):
        adata = method(adata.copy(), **kwargs)

    else:
        raise ValueError(f"Method '{method}' is not a valid batch correction method.")

    return adata  # the corrected adata object
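

# Example of a custom correction callable (illustrative sketch; "my_correction"
# is a hypothetical user-defined function, not part of sctoolbox). A callable
# passed as `method` receives a copy of the adata as first argument and must
# return the corrected adata, as described in the docstring above:
#
#   def my_correction(adata, **kwargs):
#       sc.pp.regress_out(adata, keys="batch")  # any user-defined correction step
#       sc.pp.pca(adata)
#       sc.pp.neighbors(adata)
#       return adata
#
#   corrected = batch_correction(adata, batch_key="batch", method=my_correction)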


@deco.log_anndata
@beartype
def evaluate_batch_effect(adata: sc.AnnData,
                          batch_key: str,
                          obsm_key: str = 'X_umap',
                          col_name: str = 'LISI_score',
                          max_dims: int = 5,
                          perplexity: int = 30,
                          inplace: bool = False) -> Optional[sc.AnnData]:
    """
    Evaluate batch effect methods using LISI.

    Parameters
    ----------
    adata : sc.AnnData
        Anndata object with PCA and umap/tsne for batch evaluation.
    batch_key : str
        The column in adata.obs containing batch information.
    obsm_key : str, default 'X_umap'
        The key in adata.obsm containing the coordinates.
    col_name : str, default 'LISI_score'
        Column name for storing the LISI score in .obs.
    max_dims : int, default 5
        Maximum number of dimensions of the adata.obsm matrix to use for LISI (to speed up computation).
    perplexity : int, default 30
        Perplexity for the LISI score calculation.
    inplace : bool, default False
        Whether to work inplace on the anndata object.

    Returns
    -------
    Optional[sc.AnnData]
        If inplace is True, the LISI score is added to adata.obs in place (returns None), otherwise a copy of the adata is returned.

    Notes
    -----
    - The LISI score is calculated for each cell and lies between 1 and n for a dataframe with n categorical variables.
    - It indicates the effective number of different categories represented in the local neighborhood of each cell.
    - If the cells are well-mixed, we expect the LISI score to be near n for data with n batches.
    - The higher the LISI score, the better the batch correction method worked to normalize the batch effect and mix the cells from different batches.
    - For further information on LISI: https://genomebiology.biomedcentral.com/articles/10.1186/s13059-019-1850-9

    Raises
    ------
    KeyError
        1. If obsm_key is not in adata.obsm.
        2. If batch_key is not a column in adata.obs.
    """

    # Load LISI
    utils.check_module("harmonypy")
    from harmonypy.lisi import compute_lisi

    # Handle inplace option
    adata_m = adata if inplace else adata.copy()

    # Checks
    if obsm_key not in adata_m.obsm:
        raise KeyError(f"adata.obsm does not contain the obsm key: {obsm_key}")

    if batch_key not in adata_m.obs:
        raise KeyError(f"adata.obs does not contain the batch key: {batch_key}")

    # Run LISI on the adata object
    obsm_matrix = adata_m.obsm[obsm_key][:, :max_dims]
    lisi_res = compute_lisi(obsm_matrix, adata_m.obs, [batch_key], perplexity=perplexity)
    adata_m.obs[col_name] = lisi_res.flatten()

    if not inplace:
        return adata_m
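

# Example usage (illustrative sketch; "corrected" is an assumed AnnData that
# already contains .obsm["X_umap"] and a "batch" column in .obs):
#
#   scored = evaluate_batch_effect(corrected, batch_key="batch",
#                                  obsm_key="X_umap", col_name="LISI_score")
#   scored.obs["LISI_score"].mean()  # values near the number of batches indicate good mixing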


@beartype
def wrap_batch_evaluation(adatas: dict[str, sc.AnnData],
                          batch_key: str,
                          obsm_keys: str | list[str] = ['X_pca', 'X_umap'],
                          threads: int = 1,
                          max_dims: int = 5,
                          inplace: bool = False) -> Optional[dict[str, sc.AnnData]]:
    """
    Evaluate batch correction methods for a dict of anndata objects (using LISI score calculation).

    Parameters
    ----------
    adatas : dict[str, sc.AnnData]
        Dict containing an anndata object for each batch correction method as values.
        Keys are the names of the respective methods, e.g.: {"bbknn": anndata}
    batch_key : str
        The column in adata.obs containing batch information.
    obsm_keys : str | list[str], default ['X_pca', 'X_umap']
        Key(s) to the coordinates on which the score is calculated.
    threads : int, default 1
        Number of threads to use for parallelization.
    max_dims : int, default 5
        Maximum number of dimensions of the adata.obsm matrix to use for LISI (to speed up computation).
    inplace : bool, default False
        Whether to work inplace on the anndata dict.

    Returns
    -------
    Optional[dict[str, sc.AnnData]]
        Dict containing an anndata object for each batch correction method as values, with LISI scores added to .obs.
    """

    if utils._is_notebook() is True:
        from tqdm import tqdm_notebook as tqdm
    else:
        from tqdm import tqdm

    # Handle inplace option
    adatas_m = adatas if inplace else copy.deepcopy(adatas)

    # Ensure that obsm_keys can be looped over
    if isinstance(obsm_keys, str):
        obsm_keys = [obsm_keys]

    # Evaluate batch effect for every adata
    if threads == 1:
        pbar = tqdm(total=len(adatas_m) * len(obsm_keys), desc="Calculation progress ")
        for adata in adatas_m.values():
            n_cells = adata.shape[0]
            perplexity = min(30, int(n_cells / 3))  # adjust perplexity for small datasets
            for obsm in obsm_keys:
                evaluate_batch_effect(adata, batch_key, col_name=f"LISI_score_{obsm}", obsm_key=obsm, max_dims=max_dims, perplexity=perplexity, inplace=True)
                pbar.update()
    else:
        utils.check_module("harmonypy")
        from harmonypy.lisi import compute_lisi

        pool = mp.Pool(threads)
        jobs = {}
        for i, adata in enumerate(adatas_m.values()):
            n_cells = adata.shape[0]
            perplexity = min(30, int(n_cells / 3))  # adjust perplexity for small datasets
            for obsm_key in obsm_keys:
                obsm_matrix = adata.obsm[obsm_key][:, :max_dims]
                obs_mat = adata.obs[[batch_key]]

                job = pool.apply_async(compute_lisi, args=(obsm_matrix, obs_mat, [batch_key], perplexity,))
                jobs[(i, obsm_key)] = job
        pool.close()

        # Monitor all jobs with a pbar
        utils.monitor_jobs(jobs, "Calculating LISI scores")  # waits for all jobs to finish
        pool.join()

        # Assign results to adata
        for adata_i, obsm_key in jobs:
            adata = list(adatas_m.values())[adata_i]
            adata.obs[f"LISI_score_{obsm_key}"] = jobs[(adata_i, obsm_key)].get().flatten()

    if not inplace:
        return adatas_m
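

# Example usage (illustrative sketch; "corrections" is assumed to be the dict
# returned by wrap_corrections above):
#
#   scored = wrap_batch_evaluation(corrections, batch_key="batch",
#                                  obsm_keys=["X_pca", "X_umap"], threads=4)
#   scored["harmony"].obs[["LISI_score_X_pca", "LISI_score_X_umap"]].head()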