# Source code for sctoolbox.tools.receptor_ligand

"""Tools for a receptor-ligand analysis."""
import pandas as pd
from collections import Counter
import scipy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import igraph as ig
from itertools import combinations_with_replacement
from matplotlib import cm
from tqdm import tqdm
import matplotlib
from matplotlib.patches import ConnectionPatch
import matplotlib.lines as lines
from sklearn.preprocessing import minmax_scale
import warnings
import scanpy as sc

from beartype.typing import Optional, Tuple
import numpy.typing as npt
from beartype import beartype

import sctoolbox.utils.decorator as deco
from sctoolbox._settings import settings


# -------------------------------------------------- setup functions -------------------------------------------------- #


@deco.log_anndata
@beartype
def download_db(adata: sc.AnnData,
                db_path: str,
                ligand_column: str,
                receptor_column: str,
                sep: str = "\t",
                inplace: bool = False,
                overwrite: bool = False) -> Optional[sc.AnnData]:
    r"""
    Load a table of receptor-ligand interactions and store it in adata.

    The table is read with `pandas.read_csv`, so `db_path` may be a local file or a URL.

    Parameters
    ----------
    adata : sc.AnnData
        Analysis object the database will be added to.
    db_path : str
        Path to database table. A valid database needs a column with receptor gene ids/symbols
        and ligand gene ids/symbols.
        Human: http://tcm.zju.edu.cn/celltalkdb/download/processed_data/human_lr_pair.txt
        Mouse: http://tcm.zju.edu.cn/celltalkdb/download/processed_data/mouse_lr_pair.txt
    ligand_column : str
        Name of the column with ligand gene names.
        Use 'ligand_gene_symbol' for the urls provided above.
    receptor_column : str
        Name of the column with receptor gene names.
        Use 'receptor_gene_symbol' for the urls provided above.
    sep : str, default '\t'
        Separator of the database table.
    inplace : bool, default False
        Whether to copy `adata` or modify it inplace.
    overwrite : bool, default False
        If True, an existing database is replaced. Otherwise a warning is issued and nothing is changed.

    Notes
    -----
    This will remove all information stored in adata.uns['receptor-ligand'].

    Returns
    -------
    Optional[sc.AnnData]
        None if `inplace`. Otherwise a copy of adata with the database table, its path and the
        ligand/receptor column names stored in adata.uns['receptor-ligand']. If a database already
        exists and `overwrite` is False, the unchanged input object is returned.

    Raises
    ------
    ValueError:
        1: If ligand_column is not in database.
        2: If receptor_column is not in database.
    """
    already_present = "receptor-ligand" in adata.uns and "database" in adata.uns["receptor-ligand"]
    if already_present and not overwrite:
        warnings.warn("Database already exists! Skipping. Set `overwrite=True` to replace.")
        return None if inplace else adata

    database = pd.read_csv(db_path, sep=sep)

    # validate the requested columns (ligand first, then receptor)
    for role, column in (("Ligand", ligand_column), ("Receptor", receptor_column)):
        if column not in database.columns:
            raise ValueError(f"{role} column '{column}' not found in database! Available columns: {database.columns}")

    target = adata if inplace else adata.copy()

    # any previous receptor-ligand information is discarded
    target.uns["receptor-ligand"] = {
        "database_path": db_path,
        "database": database,
        "ligand_column": ligand_column,
        "receptor_column": receptor_column,
    }

    return None if inplace else target


@deco.log_anndata
@beartype
def calculate_interaction_table(adata: sc.AnnData,
                                cluster_column: str,
                                gene_index: Optional[str] = None,
                                normalize: int = 1000,
                                inplace: bool = False,
                                overwrite: bool = False) -> Optional[sc.AnnData]:
    """
    Calculate an interaction table of the clusters defined in adata.

    For every gene, the mean expression and the percentage of expressing cells are computed per cluster.
    Cluster means are z-scored across clusters and multiplied by ``cluster size / normalize``.
    For every receptor-ligand pair of the database and every (receptor cluster, ligand cluster)
    combination, the ``interaction_score`` is the sum of the scaled receptor and ligand scores.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object that holds the expression values and clustering.
    cluster_column : str
        Name of the cluster column in adata.obs.
    gene_index : Optional[str], default None
        Column in adata.var that holds gene symbols/ ids.
        Corresponds to `download_db(ligand_column, receptor_column)`.
        Uses index when None.
    normalize : int, default 1000
        Correct clusters to given size.
    inplace : bool, default False
        Whether to copy `adata` or modify it inplace.
    overwrite : bool, default False
        If True will overwrite existing interaction table.

    Returns
    -------
    Optional[sc.AnnData]
        If not inplace, return copy of adata with added interactions table to
        adata.uns['receptor-ligand']['interactions'].

    Raises
    ------
    ValueError:
        1: If receptor-ligand database cannot be found.
        2: If database genes do not match adata genes.
    Exception:
        If no interactions were found.
    """
    if "receptor-ligand" not in adata.uns.keys():
        raise ValueError("Could not find receptor-ligand database. Please setup database with `download_db(...)` before running this function.")

    # interaction table already exists?
    if not overwrite and "interactions" in adata.uns["receptor-ligand"]:
        warnings.warn("Interaction table already exists! Skipping. Set `overwrite=True` to replace.")
        return None if inplace else adata

    database = adata.uns["receptor-ligand"]["database"]
    r_col, l_col = adata.uns["receptor-ligand"]["receptor_column"], adata.uns["receptor-ligand"]["ligand_column"]
    index = adata.var[gene_index] if gene_index else adata.var.index

    # test if database gene columns overlap with adata.var genes
    genes = set(index)
    if not set(database[r_col]) & genes or not set(database[l_col]) & genes:
        raise ValueError(f"Database columns '{r_col}', '{l_col}' don't match adata.var['{gene_index}']. Please make sure to select gene ids or symbols in all columns.")

    # ----- compute cluster means and expression percentage for each gene -----
    # gene mean expression per cluster
    cl_mean_expression = pd.DataFrame(index=index)
    # percent cells in cluster expressing gene
    cl_percent_expression = pd.DataFrame(index=index)
    # number of cells for each cluster
    clust_sizes = {}

    # fill above tables
    for cluster in tqdm(set(adata.obs[cluster_column]), desc="computing cluster gene scores"):
        # filter adata to a specific cluster
        cluster_adata = adata[adata.obs[cluster_column] == cluster]
        clust_sizes[cluster] = len(cluster_adata)

        # -- compute cluster means --
        # subsetting cells keeps adata.var unchanged, so the gene labels of the subset equal `index`
        cl_mean_expression.loc[cl_mean_expression.index.isin(index), cluster] = cluster_adata.X.mean(axis=0).reshape(-1, 1)

        # -- compute expression percentage --
        # column positions of all non-zero entries -> number of expressing cells per gene
        _, cols = cluster_adata.X.nonzero()
        gene_occurrence = Counter(cols)
        cl_percent_expression[cluster] = 0
        cl_percent_expression.iloc[list(gene_occurrence.keys()), cl_percent_expression.columns.get_loc(cluster)] = list(gene_occurrence.values())
        cl_percent_expression[cluster] = cl_percent_expression[cluster] / len(cluster_adata.obs) * 100

    # combine duplicated genes through mean (can happen due to mapping between organisms)
    if len(set(cl_mean_expression.index)) != len(cl_mean_expression):
        cl_mean_expression = cl_mean_expression.groupby(cl_mean_expression.index).mean()
        cl_percent_expression = cl_percent_expression.groupby(cl_percent_expression.index).mean()

    # cluster scaling factor for cluster size correction
    scaling_factor = {k: v / normalize for k, v in clust_sizes.items()}

    # ----- compute zscore of cluster means for each gene -----
    # create pandas functions that show progress bar
    tqdm.pandas(desc="computing Z-scores")
    zscores = cl_mean_expression.progress_apply(lambda x: pd.Series(scipy.stats.zscore(x, nan_policy='omit'), index=cl_mean_expression.columns), axis=1)

    interactions = {"receptor_cluster": [],
                    "ligand_cluster": [],
                    "receptor_gene": [],
                    "ligand_gene": [],
                    "receptor_score": [],
                    "ligand_score": [],
                    "receptor_percent": [],
                    "ligand_percent": [],
                    "receptor_scale_factor": [],
                    "ligand_scale_factor": [],
                    "receptor_cluster_size": [],
                    "ligand_cluster_size": []}

    # ----- create interaction table -----
    for _, (receptor, ligand) in tqdm(database[[r_col, l_col]].iterrows(),
                                      total=len(database),
                                      desc="finding receptor-ligand interactions"):
        # skip incomplete database rows; an identity check against np.nan does not catch every missing value
        if pd.isna(receptor) or pd.isna(ligand):
            continue

        # skip interaction if not in data
        if receptor not in zscores.index or ligand not in zscores.index:
            continue

        # add interactions to dict (every receptor cluster x ligand cluster combination)
        for receptor_cluster in zscores.columns:
            for ligand_cluster in zscores.columns:
                interactions["receptor_gene"].append(receptor)
                interactions["ligand_gene"].append(ligand)
                interactions["receptor_cluster"].append(receptor_cluster)
                interactions["ligand_cluster"].append(ligand_cluster)
                interactions["receptor_score"].append(zscores.loc[receptor, receptor_cluster])
                interactions["ligand_score"].append(zscores.loc[ligand, ligand_cluster])
                interactions["receptor_percent"].append(cl_percent_expression.loc[receptor, receptor_cluster])
                interactions["ligand_percent"].append(cl_percent_expression.loc[ligand, ligand_cluster])
                interactions["receptor_scale_factor"].append(scaling_factor[receptor_cluster])
                interactions["ligand_scale_factor"].append(scaling_factor[ligand_cluster])
                interactions["receptor_cluster_size"].append(clust_sizes[receptor_cluster])
                interactions["ligand_cluster_size"].append(clust_sizes[ligand_cluster])

    interactions = pd.DataFrame(interactions)

    # compute interaction score
    interactions["receptor_score"] = interactions["receptor_score"] * interactions["receptor_scale_factor"]
    interactions["ligand_score"] = interactions["ligand_score"] * interactions["ligand_scale_factor"]
    interactions["interaction_score"] = interactions["receptor_score"] + interactions["ligand_score"]

    # clean up columns
    interactions.drop(columns=["receptor_scale_factor", "ligand_scale_factor"], inplace=True)

    # no interactions found error
    if not len(interactions):
        raise Exception("Failed to find any receptor-ligand interactions. Consider using a different database.")

    # add to adata
    modified_adata = adata if inplace else adata.copy()
    modified_adata.uns['receptor-ligand']['interactions'] = interactions

    return None if inplace else modified_adata


# -------------------------------------------------- plotting functions -------------------------------------------------- #
@deco.log_anndata
@beartype
def interaction_violin_plot(adata: sc.AnnData,
                            min_perc: int | float,
                            save: Optional[str] = None,
                            figsize: Tuple[int, int] = (5, 20),
                            dpi: int = 100) -> npt.ArrayLike:
    """
    Generate violin plot of pairwise cluster interactions.

    One subplot is drawn for every cluster that appears as receptor or ligand cluster.
    It shows the interaction scores of that cluster grouped by the partner cluster.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object
    min_perc : int | float
        Minimum percentage of cells in a cluster that express the respective gene. A value from 0-100.
    save : str, default None
        Output filename. Uses the internal 'sctoolbox.settings.figure_dir'.
    figsize : int tuple, default (5, 20)
        Figure size
    dpi : int, default 100
        The resolution of the figure in dots-per-inch.

    Returns
    -------
    npt.ArrayLike
        1D array with one matplotlib Axes per cluster (also when there is a single cluster).
    """
    # check if data is available
    _check_interactions(adata)

    interactions = get_interactions(adata)

    # one subplot per cluster occurring on either side of an interaction
    clusters = sorted(set(interactions["receptor_cluster"].tolist() + interactions["ligand_cluster"].tolist()))

    # squeeze=False: always get an array of Axes, even for a single row
    fig, axs = plt.subplots(ncols=1,
                            nrows=len(clusters),
                            figsize=figsize,
                            dpi=dpi,
                            squeeze=False,
                            tight_layout={'rect': (0, 0, 1, 0.95)})  # prevent label clipping; leave space for title
    axs = axs.flatten()

    # generate violins of one cluster vs rest in each iteration
    for ax, cluster in zip(axs, clusters):
        cluster_interactions = get_interactions(adata, min_perc=min_perc, group_a=[cluster]).copy()

        # partner cluster = the side of the interaction that is not `cluster`
        is_receptor_side = cluster_interactions["receptor_cluster"] == cluster
        cluster_interactions["Cluster"] = cluster_interactions["ligand_cluster"].where(
            is_receptor_side, cluster_interactions["receptor_cluster"]
        )

        plot = sns.violinplot(x=cluster_interactions["Cluster"], y=cluster_interactions["interaction_score"], ax=ax)
        plot.tick_params(axis="x", rotation=90)

        ax.set_title(f"Cluster {cluster}")

    # save plot
    if save:
        fig.savefig(f"{settings.figure_dir}/{save}")

    return axs


@deco.log_anndata
@beartype
def hairball(adata: sc.AnnData,
             min_perc: int | float,
             interaction_score: float | int = 0,
             interaction_perc: Optional[int | float] = None,
             save: Optional[str] = None,
             title: Optional[str] = "Network",
             color_min: float | int = 0,
             color_max: Optional[float | int] = None,
             cbar_label: str = "Interaction count",
             show_count: bool = False,
             restrict_to: Optional[list[str]] = None,
             additional_nodes: Optional[list[str]] = None,
             hide_edges: Optional[list[Tuple[str, str]]] = None) -> npt.ArrayLike:
    """
    Generate network graph of interactions between clusters.

    Each node is a cluster. The weight of an edge is the number of interactions between the two clusters
    that pass the `min_perc` / `interaction_score` / `interaction_perc` filters of `get_interactions`.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object
    min_perc : int | float
        Minimum percentage of cells in a cluster that express the respective gene. A value from 0-100.
    interaction_score : float | int, default 0
        Interaction score must be above this threshold for the interaction to be counted in the graph.
    interaction_perc : Optional[int | float], default None
        Select interaction scores above or equal to the given percentile.
        Will overwrite parameter interaction_score. A value from 0-100.
    save : str, default None
        Output filename. Uses the internal 'sctoolbox.settings.figure_dir'.
    title : str, default 'Network'
        The plots title.
    color_min : float, default 0
        Min value for color range. Used for both the edge colors and the colorbar.
    color_max : Optional[float | int], default None
        Max value for color range. Defaults to the largest edge weight.
    cbar_label : str, default 'Interaction count'
        Label above the colorbar.
    show_count : bool, default False
        Show the interaction count in the hairball.
    restrict_to : Optional[list[str]], default None
        Only show given clusters provided in list. The list itself is not modified.
    additional_nodes : Optional[list[str]], default None
        List of additional node names displayed in the hairball.
    hide_edges : Optional[list[Tuple[str, str]]], default None
        List of tuples with node names that should not have an edge shown. Order doesn't matter.
        E.g. `[("a", "b")]` to omit the edge between node a and b.

    Returns
    -------
    npt.ArrayLike
        Object containing all plots. As returned by matplotlib.pyplot.subplots

    Raises
    ------
    ValueError:
        If restrict_to contains invalid clusters.
    """
    # check if data is available
    _check_interactions(adata)

    interactions = get_interactions(adata)

    # any invalid cluster names
    if restrict_to:
        valid_clusters = set.union(set(interactions["ligand_cluster"]),
                                   set(interactions["receptor_cluster"]),
                                   set(additional_nodes) if additional_nodes else set())
        invalid_clusters = set(restrict_to) - valid_clusters
        if invalid_clusters:
            raise ValueError(f"Invalid cluster in `restrict_to`: {invalid_clusters}")

    # ----- create igraph -----
    graph = ig.Graph()

    # --- set nodes ---
    # copy `restrict_to` so that appending `additional_nodes` does not modify the caller's list
    if restrict_to:
        clusters = list(restrict_to)
    else:
        clusters = list(set(list(interactions["receptor_cluster"]) + list(interactions["ligand_cluster"])))

    # add additional nodes
    if additional_nodes:
        clusters += additional_nodes

    graph.add_vertices(clusters)
    graph.vs['label'] = clusters
    graph.vs['size'] = 0.1  # node size
    graph.vs['label_size'] = 12  # label size
    graph.vs['label_dist'] = 2  # distance of label to node # not working
    graph.vs['label_angle'] = 1.5708  # rad = 90 degree # not working

    # --- set edges ---
    for (a, b) in combinations_with_replacement(clusters, 2):
        if hide_edges and ((a, b) in hide_edges or (b, a) in hide_edges):
            continue

        subset = get_interactions(adata,
                                  min_perc=min_perc,
                                  interaction_score=interaction_score,
                                  interaction_perc=interaction_perc,
                                  group_a=[a],
                                  group_b=[b])

        graph.add_edge(a, b, weight=len(subset))

    # set edge colors/ width based on weight
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is the supported equivalent
    colormap = plt.get_cmap('viridis', len(graph.es))

    print(f"Max weight {np.max(np.array(graph.es['weight']))}")
    max_weight = np.max(np.array(graph.es['weight'])) if color_max is None else color_max

    # one normalization shared by edge colors and colorbar, so the colorbar describes the edges
    norm = matplotlib.colors.Normalize(0 if color_min is None else color_min, max_weight)

    for e in graph.es:
        e["color"] = colormap(norm(e["weight"]), alpha=e["weight"] / max_weight)
        e["width"] = (e["weight"] / max_weight)  # * 10
        # show weights in plot
        if show_count and e["weight"] > 0:
            e["label"] = e["weight"]
            e["label_size"] = 25

    # ----- plotting -----
    # Create the figure
    fig, axes = plt.subplots(1, 2, figsize=(8, 6), gridspec_kw={'width_ratios': [20, 1]})
    fig.suptitle(title, fontsize=12)

    ig.plot(obj=graph, layout=graph.layout_circle(order=sorted(clusters)), target=axes[0])

    # add colorbar
    cb = matplotlib.colorbar.ColorbarBase(axes[1],
                                          orientation='vertical',
                                          cmap=colormap,
                                          norm=norm)

    cb.ax.tick_params(labelsize=10)
    cb.ax.set_title(cbar_label, fontsize=10)

    # prevent label clipping out of picture
    plt.tight_layout()
    plt.subplots_adjust(right=0.9)

    if save:
        fig.savefig(f"{settings.figure_dir}/{save}")

    return axes


@beartype
def progress_violins(datalist: list[pd.DataFrame],
                     datalabel: list[str],
                     cluster_a: str,
                     cluster_b: str,
                     min_perc: float | int,
                     save: str,
                     figsize: Tuple[int | float, int | float] = (12, 6)) -> str:
    """
    Show cluster interactions over timepoints.

    CURRENTLY NOT FUNCTIONAL! No plot is drawn; a placeholder message is returned.

    TODO Implement function: for every table in `datalist`, select the interactions between
    `cluster_a` and `cluster_b` (in either receptor/ligand direction) where both sides are
    expressed in at least `min_perc` percent of cells. Then draw one violin of `interaction_score`
    per timepoint, labeled by `datalabel`. Note that current interaction tables use the columns
    `receptor_cluster`/`ligand_cluster` and `receptor_percent`/`ligand_percent`.

    Parameters
    ----------
    datalist : list[pd.DataFrame]
        List of interaction DataFrames. Each DataFrame represents a timepoint.
    datalabel : list[str]
        List of strings. Used to label the violins.
    cluster_a : str
        Name of the first interacting cluster.
    cluster_b : str
        Name of the second interacting cluster.
    min_perc : float | int
        Minimum percentage of cells in a cluster each gene must be expressed in.
    save : str
        Path to output file.
    figsize : Tuple[int, int], default (12, 6)
        Tuple of plot (width, height).

    Returns
    -------
    str
        The placeholder message "Function to be implemented".
    """
    # the previous draft after this return was unreachable and used an outdated table schema
    return "Function to be implemented"


@beartype
def interaction_progress(datalist: list[sc.AnnData],
                         datalabel: list[str],
                         receptor: str,
                         ligand: str,
                         receptor_cluster: str,
                         ligand_cluster: str,
                         figsize: Tuple[int | float, int | float] = (4, 4),
                         dpi: int = 100,
                         save: Optional[str] = None) -> matplotlib.axes.Axes:
    """
    Barplot that shows the interaction score of a single interaction between two given clusters over multiple datasets.

    Parameters
    ----------
    datalist : list[sc.AnnData]
        List of anndata objects. Each must contain an interaction table
        (see `calculate_interaction_table`).
    datalabel : list[str]
        List of labels for the given datalist. Must have the same length as `datalist`.
    receptor : str
        Name of the receptor gene.
    ligand : str
        Name of the ligand gene.
    receptor_cluster : str
        Name of the receptor cluster.
    ligand_cluster : str
        Name of the ligand cluster.
    figsize : Tuple[int | float, int | float], default (4, 4)
        Figure size in inch.
    dpi : int, default 100
        Dots per inch.
    save : Optional[str], default None
        Output filename. Uses the internal 'sctoolbox.settings.figure_dir'.

    Returns
    -------
    matplotlib.axes.Axes
        The plotting object.

    Raises
    ------
    ValueError:
        1: If `datalist` is empty.
        2: If `datalist` and `datalabel` differ in length.
        3: If an object in `datalist` has no interaction table.
    """
    if not datalist:
        raise ValueError("`datalist` must contain at least one AnnData object.")
    # zip() would otherwise silently drop the unmatched entries
    if len(datalist) != len(datalabel):
        raise ValueError(f"`datalist` and `datalabel` must have the same length (got {len(datalist)} and {len(datalabel)}).")

    table = []
    for data, label in zip(datalist, datalabel):
        # clear error message instead of a KeyError for objects without interactions
        _check_interactions(data)

        # interactions
        inter = data.uns["receptor-ligand"]["interactions"]

        # select interaction
        inter = inter[
            (inter["receptor_cluster"] == receptor_cluster)
            & (inter["ligand_cluster"] == ligand_cluster)
            & (inter["receptor_gene"] == receptor)
            & (inter["ligand_gene"] == ligand)
        ].copy()

        # add datalabel
        inter["name"] = label

        table.append(inter)

    table = pd.concat(table)

    # plot
    with plt.rc_context({"figure.figsize": figsize, "figure.dpi": dpi}):
        plot = sns.barplot(
            data=table,
            x="name",
            y="interaction_score"
        )

        plot.set(
            title=f"{receptor} - {ligand}\n{receptor_cluster} - {ligand_cluster}",
            ylabel="Interaction Score",
            xlabel=""
        )

        # rotate the existing labels instead of replacing them (set_xticklabels warns without a FixedLocator)
        plt.setp(plot.get_xticklabels(), rotation=90, horizontalalignment='right')

        plt.tight_layout()

        if save:
            plt.savefig(f"{settings.figure_dir}/{save}")

    return plot


@deco.log_anndata
@beartype
def connectionPlot(adata: sc.AnnData,
                   restrict_to: Optional[list[str]] = None,
                   figsize: Tuple[int | float, int | float] = (10, 15),
                   dpi: int = 100,
                   connection_alpha: Optional[str] = "interaction_score",
                   save: Optional[str] = None,
                   title: Optional[str] = None,
                   # receptor params
                   receptor_cluster_col: str = "receptor_cluster",
                   receptor_col: str = "receptor_gene",
                   receptor_hue: str = "receptor_score",
                   receptor_size: str = "receptor_percent",
                   receptor_genes: Optional[list[str]] = None,
                   # ligand params
                   ligand_cluster_col: str = "ligand_cluster",
                   ligand_col: str = "ligand_gene",
                   ligand_hue: str = "ligand_score",
                   ligand_size: str = "ligand_percent",
                   ligand_genes: Optional[list[str]] = None,
                   filter: Optional[str] = None,
                   lw_multiplier: int | float = 2,
                   wspace: float = 0.4,
                   line_colors: Optional[str] = "rainbow") -> npt.ArrayLike:
    """
    Show specific receptor-ligand connections between clusters.

    Receptors are drawn as a dotplot on the left and ligands as a dotplot on the right.
    For every receptor-ligand gene pair, a line connects the two dotplots; its alpha and width
    come from the strongest connection of that pair.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object
    restrict_to : Optional[list[str]], default None
        Restrict plot to given cluster names.
    figsize : Tuple[int | float, int | float], default (10, 15)
        Figure size
    dpi : int, default 100
        The resolution of the figure in dots-per-inch.
    connection_alpha : str, default 'interaction_score'
        Name of column that sets alpha value of lines between plots. None to disable.
    save : Optional[str], default None
        Output filename. Uses the internal 'sctoolbox.settings.figure_dir'.
    title : Optional[str], default None
        Title of the plot
    receptor_cluster_col : str, default 'receptor_cluster'
        Name of column containing cluster names of receptors. Shown on x-axis.
    receptor_col : str, default 'receptor_gene'
        Name of column containing gene names of receptors. Shown on y-axis.
    receptor_hue : str, default 'receptor_score'
        Name of column containing receptor scores. Shown as point color.
    receptor_size : str, default 'receptor_percent'
        Name of column containing receptor expression percentage. Shown as point size.
    receptor_genes : Optional[list[str]], default None
        Restrict receptors to given genes.
    ligand_cluster_col : str, default 'ligand_cluster'
        Name of column containing cluster names of ligands. Shown on x-axis.
    ligand_col : str, default 'ligand_gene'
        Name of column containing gene names of ligands. Shown on y-axis.
    ligand_hue : str, default 'ligand_score'
        Name of column containing ligand scores. Shown as point color.
    ligand_size : str, default 'ligand_percent'
        Name of column containing ligand expression percentage. Shown as point size.
    ligand_genes : Optional[list[str]], default None
        Restrict ligands to given genes.
    filter : Optional[str], default None
        Conditions to filter the interaction table on. E.g. 'column_name > 5 & other_column < 2'.
        Forwarded to pandas.DataFrame.query.
    lw_multiplier : int | float, default 2
        Linewidth multiplier.
    wspace : float, default 0.4
        Width between plots. Fraction of total width.
    line_colors : Optional[str], default 'rainbow'
        Name of colormap used to color lines. All lines are black if None.

    Returns
    -------
    npt.ArrayLike
        Object containing all plots. As returned by matplotlib.pyplot.subplots

    Raises
    ------
    Exception:
        If no interactions between clusters are found.
    """
    # check if data is available
    _check_interactions(adata)

    data = get_interactions(adata).copy()

    # filter receptor genes
    if receptor_genes:
        data = data[data[receptor_col].isin(receptor_genes)]

    # filter ligand genes
    if ligand_genes:
        data = data[data[ligand_col].isin(ligand_genes)]

    # filter interactions
    if filter:
        data = data.query(filter)

    # restrict interactions to certain clusters
    if restrict_to:
        data = data[data[receptor_cluster_col].isin(restrict_to) & data[ligand_cluster_col].isin(restrict_to)]

    if len(data) < 1:
        raise Exception(f"No interactions between clusters {restrict_to}")

    # own copy: helper columns ("alpha", "rec_index", "lig_index") are added below
    data = data.copy()

    # setup subplot
    fig, axs = plt.subplots(1, 2, figsize=figsize, dpi=dpi, gridspec_kw={'wspace': wspace})
    fig.suptitle(title)

    # receptor plot
    r_plot = sns.scatterplot(data=data,
                             y=receptor_col,
                             x=receptor_cluster_col,
                             hue=receptor_hue,
                             size=receptor_size,
                             ax=axs[0])

    r_plot.set(xlabel="Cluster", ylabel=None, title="Receptor", axisbelow=True)
    axs[0].tick_params(axis='x', rotation=90)
    axs[0].grid(alpha=0.8)

    # ligand plot
    l_plot = sns.scatterplot(data=data,
                             y=ligand_col,
                             x=ligand_cluster_col,
                             hue=ligand_hue,
                             size=ligand_size,
                             ax=axs[1])

    axs[1].yaxis.tick_right()
    l_plot.set(xlabel="Cluster", ylabel=None, title="Ligand", axisbelow=True)
    axs[1].tick_params(axis='x', rotation=90)
    axs[1].grid(alpha=0.8)

    # force tick labels to be populated
    # https://stackoverflow.com/questions/41122923/getting-empty-tick-labels-before-showing-a-plot-in-matplotlib
    fig.canvas.draw()

    # add receptor-ligand lines
    # unique() keeps first-appearance order, so the line colors are reproducible (set order is not)
    receptors = list(data[receptor_col].unique())

    # create colorramp
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is the supported equivalent
    if line_colors:
        cmap = plt.get_cmap(line_colors, len(receptors))
        colors = cmap(range(len(receptors)))
    else:
        colors = ["black"] * len(receptors)

    # scale connection score column between 0-1 to be used as alpha values
    if connection_alpha:
        # note: minmax_scale sometimes produces values >1. Looks like a rounding error (1.000000000002).
        data["alpha"] = minmax_scale(data[connection_alpha], feature_range=(0, 1))
        # fix values >1
        data.loc[data["alpha"] > 1, "alpha"] = 1
    else:
        data["alpha"] = 1

    # find receptor label location
    for i, label in enumerate(axs[0].get_yticklabels()):
        data.loc[data[receptor_col] == label.get_text(), "rec_index"] = i

    # find ligand label location
    for i, label in enumerate(axs[1].get_yticklabels()):
        data.loc[data[ligand_col] == label.get_text(), "lig_index"] = i

    # add receptor-ligand lines
    # draws strongest connection for each pair
    for rec, color in zip(receptors, colors):
        pairs = data.loc[data[receptor_col] == rec]

        for lig in set(pairs[ligand_col]):
            # get all connections for current pair
            connections = pairs.loc[pairs[ligand_col] == lig]

            # get max connection
            max_con = connections.loc[connections["alpha"].idxmax()]

            # stolen from https://matplotlib.org/stable/gallery/userdemo/connect_simple01.html
            # Draw a line between the different points, defined in different coordinate
            # systems.
            con = ConnectionPatch(
                # x in axes coordinates, y in data coordinates
                xyA=(1, max_con["rec_index"]), coordsA=axs[0].get_yaxis_transform(),
                # x in axes coordinates, y in data coordinates
                xyB=(0, max_con["lig_index"]), coordsB=axs[1].get_yaxis_transform(),
                arrowstyle="-",
                color=color,
                zorder=-1000,
                alpha=max_con["alpha"],
                linewidth=max_con["alpha"] * lw_multiplier
            )

            axs[1].add_artist(con)

    # ----- legends -----
    # set receptor plot legend position
    sns.move_legend(r_plot, loc='upper right', bbox_to_anchor=(-1, 1, 0, 0))

    # create legend for connection lines
    if connection_alpha:
        step_num = 5
        # Series.min/max skip NaN scores (the builtins can return NaN depending on row order)
        s_steps = np.linspace(data[connection_alpha].min(), data[connection_alpha].max(), step_num)
        a_steps = np.linspace(0, 1, step_num)

        # create proxy actors https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html#proxy-legend-handles
        line_list = [lines.Line2D([], [], color="black", alpha=a, linewidth=a * lw_multiplier, label=f"{np.round(s, 2)}")
                     for a, s in zip(a_steps, s_steps)]
        line_list.insert(0, lines.Line2D([], [], alpha=0, label=connection_alpha))

        # add to current legend
        handles, _ = axs[1].get_legend_handles_labels()
        axs[1].legend(handles=handles + line_list, bbox_to_anchor=(2, 1, 0, 0), loc='upper left')
    else:
        # set ligand plot legend position
        axs[1].legend(bbox_to_anchor=(2, 1, 0, 0), loc='upper left')

    if save:
        fig.savefig(f"{settings.figure_dir}/{save}", bbox_inches='tight')

    return axs


# -------------------------------------------------- helper functions -------------------------------------------------- #
@deco.log_anndata
@beartype
def get_interactions(anndata: sc.AnnData,
                     min_perc: Optional[float | int] = None,
                     interaction_score: Optional[float | int] = None,
                     interaction_perc: Optional[float | int] = None,
                     group_a: Optional[list[str]] = None,
                     group_b: Optional[list[str]] = None,
                     save: Optional[str] = None) -> pd.DataFrame:
    """
    Get interaction table from anndata and apply filters.

    Parameters
    ----------
    anndata : sc.AnnData
        Anndata object to pull interaction table from.
    min_perc : Optional[float | int], default None
        Minimum percent of cells in a cluster that express the ligand/ receptor gene. Value from 0-100.
    interaction_score : Optional[float | int], default None
        Keep only interactions with a score above this value. Ignored if `interaction_perc` is set.
    interaction_perc : Optional[float | int], default None
        Keep only interactions with a score above or equal to the given percentile of all
        (non-NaN) interaction scores. Overwrites `interaction_score`, also when set to 0. Value from 0-100.
    group_a : Optional[list[str]], default None
        List of cluster names that must be present in any given receptor-ligand interaction.
    group_b : Optional[list[str]], default None
        List of cluster names that must be present in any given receptor-ligand interaction.
    save : Optional[str], default None
        Output filename. Uses the internal 'sctoolbox.settings.table_dir'.

    Returns
    -------
    pd.DataFrame
        Table that contains interactions. Rows with a NaN interaction score are excluded.
        Columns:

        - receptor_cluster = name of the receptor cluster
        - ligand_cluster = name of the ligand cluster
        - receptor_gene = name of the receptor gene
        - ligand_gene = name of the ligand gene
        - receptor_score = zscore of receptor gene cluster mean expression (scaled by cluster size)
        - ligand_score = zscore of ligand gene cluster mean expression (scaled by cluster size)
        - receptor_percent = percent of cells in cluster expressing receptor gene
        - ligand_percent = percent of cells in cluster expressing ligand gene
        - receptor_cluster_size = number of cells in receptor cluster
        - ligand_cluster_size = number of cells in ligand cluster
        - interaction_score = sum of receptor_score and ligand_score
    """
    # check if data is available
    _check_interactions(anndata)

    table = anndata.uns["receptor-ligand"]["interactions"]
    scores = table["interaction_score"]

    if min_perc is None:
        min_perc = 0

    # score filter; `is not None` so that interaction_perc=0 still takes precedence
    if interaction_perc is not None:
        # nanpercentile: NaN scores (e.g. zero-variance genes) must not turn the threshold into NaN
        score_mask = scores >= np.nanpercentile(scores, interaction_perc)
    elif interaction_score is not None:
        score_mask = scores > interaction_score
    else:
        # no score threshold: keep every scored interaction
        score_mask = scores.notna()

    subset = table[
        (table["receptor_percent"] >= min_perc)
        & (table["ligand_percent"] >= min_perc)
        & score_mask
    ]

    if group_a and group_b:
        # interactions between the two groups, in either receptor/ligand direction
        subset = subset[(subset["receptor_cluster"].isin(group_a) & subset["ligand_cluster"].isin(group_b))
                        | (subset["receptor_cluster"].isin(group_b) & subset["ligand_cluster"].isin(group_a))]
    elif group_a or group_b:
        group = group_a if group_a else group_b
        subset = subset[subset["receptor_cluster"].isin(group) | subset["ligand_cluster"].isin(group)]

    if save:
        subset.to_csv(f"{settings.table_dir}/{save}", sep='\t', index=False)

    return subset


@beartype
def _check_interactions(anndata: sc.AnnData):
    """
    Raise an error if `anndata` holds no receptor-ligand interaction table.

    Parameters
    ----------
    anndata : sc.AnnData
        Object expected to contain ``anndata.uns["receptor-ligand"]["interactions"]``.

    Raises
    ------
    ValueError
        If ``anndata.uns`` has no "receptor-ligand" entry, or that entry has no "interactions" key.
    """
    has_interactions = ("receptor-ligand" in anndata.uns.keys()
                        and "interactions" in anndata.uns["receptor-ligand"].keys())
    if has_interactions:
        return

    raise ValueError("Could not find interaction data! Please setup with `calculate_interaction_table(...)` before running this function.")