"""Module for type checking functions."""
import re
import os
import sys
import importlib
import matplotlib
import numpy as np
import gzip
import shutil
import scanpy as sc
import pandas as pd
from beartype.typing import Optional, Tuple, Any, Iterable
from beartype import beartype
import numpy.typing as npt
import sctoolbox.utils as utils
from sctoolbox._settings import settings
logger = settings.logger  # module logger taken from the package-wide sctoolbox settings object
def check_module(module: str) -> None:
    """
    Check if <module> can be imported without error.

    Parameters
    ----------
    module : str
        Name of the module to check.

    Raises
    ------
    ImportError
        If the module is not available for import. The underlying
        ModuleNotFoundError is attached as ``__cause__``.
    """
    try:
        importlib.import_module(module)
    except ModuleNotFoundError as err:
        # Chain the original error so the traceback still shows which import
        # actually failed (e.g. a missing dependency of the requested module).
        s = f"ERROR: Could not find the '{module}' module on path, but the module is needed for this functionality. Please install this package to proceed."
        raise ImportError(s) from err
    # Any other exception raised while importing the module propagates unchanged.
def _is_interactive() -> bool:
    """
    Report whether the active matplotlib backend is the ipympl notebook backend.

    Returns
    -------
    bool
        True when ``matplotlib.get_backend()`` reports the ipympl backend,
        False for every other backend.
    """
    ipympl_backend = 'module://ipympl.backend_nbagg'
    return matplotlib.get_backend() == ipympl_backend
def _add_path() -> str:
    """
    Add the directory of the running python executable to the PATH environment variable.

    The directory is appended only if it is not already one of the PATH entries,
    so repeated calls do not grow PATH.

    Returns
    -------
    str
        Path to the directory containing the python executable.
    """
    python_exec_dir = os.path.dirname(sys.executable)
    current_path = os.environ.get('PATH', '')

    # Compare whole PATH entries: a substring test would wrongly treat e.g.
    # '/opt/env/bin2' as already containing '/opt/env/bin'.
    entries = current_path.split(os.pathsep) if current_path else []
    if python_exec_dir not in entries:
        # Joining the entry list avoids a leading separator on an empty PATH
        # (an empty PATH entry means "current directory" on POSIX).
        os.environ['PATH'] = os.pathsep.join(entries + [python_exec_dir])
    return python_exec_dir
# ------------------------- Type checking ------------------------ #
def _is_gz_file(filepath: str) -> bool:
    """
    Check whether a file is gzip-compressed by inspecting its magic number.

    Parameters
    ----------
    filepath : str
        Path to file.

    Returns
    -------
    bool
        True if the first two bytes of the file are the gzip magic number ``1f 8b``.
    """
    gzip_magic = b'\x1f\x8b'
    with open(filepath, mode='rb') as handle:
        header = handle.read(len(gzip_magic))
    return header == gzip_magic
def gunzip_file(f_in: str, f_out: str) -> None:
    """
    Decompress a gzip file into a new file.

    Parameters
    ----------
    f_in : str
        Path to the gzip-compressed input file.
    f_out : str
        Destination path for the decompressed output file (overwritten if present).
    """
    with gzip.open(f_in, 'rb') as compressed, open(f_out, 'wb') as decompressed:
        shutil.copyfileobj(compressed, decompressed)
def is_str_numeric(ans: str) -> bool:
    """
    Check if a string can be converted to a number.

    Parameters
    ----------
    ans : str
        String to check.

    Returns
    -------
    bool
        True if ``float(ans)`` succeeds, False if it raises ValueError.
    """
    try:
        float(ans)
    except ValueError:
        return False
    return True
def var_index_from(adata: sc.AnnData,
                   from_column: Optional[str] = None) -> None:
    """
    Format adata.var index from specified column or from the index available.

    This formats the index of adata.var according to the pattern ["chr", "start", "stop"].
    The adata is changed inplace.

    The format is detected from the first source value with `get_index_type`:

    - 'snapatac' (e.g. ``b'chr1':12324-56757``): the source column is cleaned of the
      ``b'`` / ``'`` markers, split into 'peak_chr', 'peak_start' and 'peak_end'
      columns and then used as the new index.
    - 'start_name' (e.g. ``some_name-chr1:12343-76899``): the coordinate part of each
      value is extracted and used as the new index.
    - any other format: adata.var is left unchanged.

    Parameters
    ----------
    adata : sc.AnnData
        The anndata object to reformat.
    from_column : Optional[str], default None
        Column name in adata.var to be set as index. If None, the current index is used.

    Raises
    ------
    ValueError
        If the format is 'start_name' but a value contains no chromosome coordinates.
    """
    if from_column is None:
        source_values = adata.var.index
        # NOTE(review): for snapatac-formatted indices the reformatting reads the
        # 'name' column rather than the index itself (kept from the original
        # implementation) — confirm that column always exists for such data.
        snapatac_column = 'name'
    else:
        source_values = adata.var[from_column]
        snapatac_column = from_column

    if len(source_values) == 0:
        return  # nothing to reformat (and no entry to detect the format from)

    index_type = get_index_type(next(iter(source_values)))

    if index_type == 'snapatac':
        _format_snapatac_index(adata, snapatac_column)
    elif index_type == 'start_name':
        _format_start_name_index(adata, source_values)


def _format_snapatac_index(adata: sc.AnnData, column: str) -> None:
    """Clean snapATAC-style names in ``adata.var[column]``, split them into coordinates and use them as index."""
    # strip the byte-string markers, e.g. b'chr1':100-200 -> chr1:100-200
    adata.var[column] = adata.var[column].str.replace("b'", "", regex=False)
    adata.var[column] = adata.var[column].str.replace("'", "", regex=False)

    # split "chr:start-end" into chromosome and "start-end", then into start and end
    adata.var[['peak_chr', 'start_end']] = adata.var[column].str.split(':', expand=True)
    adata.var[['peak_start', 'peak_end']] = adata.var['start_end'].str.split('-', expand=True)

    adata.var['peak_chr'] = adata.var['peak_chr'].astype(str)
    adata.var['peak_start'] = adata.var['peak_start'].astype(int)
    adata.var['peak_end'] = adata.var['peak_end'].astype(int)

    adata.var.drop('start_end', axis=1, inplace=True)
    adata.var = adata.var.set_index(column)


def _format_start_name_index(adata: sc.AnnData, values: Iterable[str]) -> None:
    """Replace the adata.var index by the chromosome coordinates found in each of ``values``."""
    coordinate_pattern = re.compile(r"(chr[0-9XYM]+)+[\_\:\-]+[0-9]+[\_\:\-]+[0-9]+")
    new_index = []
    for value in values:
        match = coordinate_pattern.search(value)
        if match is None:
            raise ValueError(f"Could not find genomic coordinates in '{value}'.")
        new_index.append(match.group(0))
    adata.var['new_index'] = new_index
    adata.var.set_index('new_index', inplace=True)
def get_index_type(entry: str) -> Optional[str]:
    """
    Check the format of the index by regex.

    Parameters
    ----------
    entry : str
        String to identify the format on.

    Returns
    -------
    Optional[str]
        The index format. Either 'snapatac', 'start_name' or None for unknown format.
    """
    # Chromosome names: numbered chromosomes plus chrX, chrY and chrM.
    chrom = r"(chr[0-9XYM]+)+"
    regex_snapatac = rf"^b'{chrom}'[\_\:\-]+[0-9]+[\_\:\-]+[0-9]+"  # matches: b'chr1':12324-56757
    regex_start_name = rf"^.+{chrom}[\_\:\-]+[0-9]+[\_\:\-]+[0-9]+"  # matches: some_name-chr1:12343-76899

    if re.match(regex_snapatac, entry):
        return 'snapatac'
    if re.match(regex_start_name, entry):
        return 'start_name'
    return None
def validate_regions(adata: sc.AnnData,
                     coordinate_columns: Iterable[str]) -> None:
    """
    Check if the regions in adata.var are valid.

    A region is valid if its chromosome is a ``str``, its start and end are ``int``
    and start <= end.

    Parameters
    ----------
    adata : sc.AnnData
        AnnData object containing the regions to be checked.
    coordinate_columns : Iterable[str]
        List of length 3 for column names in adata.var containing chr, start, end coordinates.

    Raises
    ------
    ValueError
        If `coordinate_columns` does not hold exactly three names or if invalid regions are detected.
        Errors raised by ``utils.check_columns`` (e.g. for missing columns) propagate.
    """
    # Materialize once: a one-shot iterable (e.g. a generator) would otherwise be
    # exhausted by the unpacking below before the column check sees it.
    coordinate_columns = list(coordinate_columns)
    if len(coordinate_columns) != 3:
        raise ValueError("coordinate_columns must contain exactly 3 column names (chr, start, end), "
                         f"but got: {coordinate_columns}")
    chr_col, start_col, end_col = coordinate_columns

    # Test if coordinate columns are in adata.var
    utils.check_columns(adata.var, coordinate_columns, name="adata.var")

    # Only convert the three relevant columns; orient="records" (unlike "index")
    # does not require unique var names.
    for region in adata.var[coordinate_columns].to_dict(orient="records"):
        chrom, start, end = region[chr_col], region[start_col], region[end_col]
        valid = (isinstance(chrom, str) and isinstance(start, int) and isinstance(end, int)
                 and start <= end)  # start must not be larger than end
        if not valid:
            raise ValueError("The region {0}:{1}-{2} is not a valid genome region. Please check the format of columns: {3}".format(chrom, start, end, coordinate_columns))
def in_range(value: int | float, limits: Tuple[int | float, int | float],
include_limits: bool = True) -> bool:
Check if a value is in a given range.
value : int | float
Number to check if in range.
limits : Tuple[int | float, int | float]
Lower and upper limits. E.g. (0, 10)
include_limits : bool, default True
If True includes limits in accepted range.
Returns whether the value is between the set limits.
.. exec_code::
# --- hide: start ---
import sctoolbox.utilities as utils
# --- hide: stop ---
limit = (0.5, 1)
value = 0.5
print(utils.in_range(value=value, limits=limit, include_limits=True))
This will return 'True'; the value is in between the limits including the limits.
if include_limits:
return value >= limits[0] and value <= limits[1]
return value > limits[0] and value < limits[1]
def is_integer_array(arr: npt.ArrayLike) -> bool:
    """
    Check if all values of arr are integers.

    Parameters
    ----------
    arr : npt.ArrayLike
        Array of values to be checked.

    Returns
    -------
    bool
        True if all values are integers, False otherwise.
    """
    # A value is integral exactly when its remainder modulo 1 is zero
    # (see https://stackoverflow.com/a/7236784).
    remainders = np.mod(arr, 1)
    return bool(np.equal(remainders, 0).all())
def check_columns(df: pd.DataFrame,
                  columns: Iterable[str],
                  error: bool = True,
                  name: str = "dataframe") -> Optional[bool]:
    """
    Check whether columns are found within a pandas dataframe.

    TODO do we need this?

    Parameters
    ----------
    df : pd.DataFrame
        A pandas dataframe to check.
    columns : Iterable[str]
        A list of column names to check for within `df`. None entries are ignored.
    error : bool, default True
        If True raise error if not all columns are found.
        If False return True or False.
    name : str, default dataframe
        Dataframe name displayed in the error message.

    Returns
    -------
    Optional[bool]
        True or False depending on if columns are in dataframe.
        None if error is set to True.

    Raises
    ------
    KeyError
        If any of the columns are not in 'df' and error is set to True.
    """
    df_columns = df.columns
    not_found = [column for column in columns
                 if column is not None and column not in df_columns]

    if not_found:
        if not error:
            return False
        error_str = f"Columns '{not_found}' are not found in {name}. Available columns are: {list(df_columns)}"
        raise KeyError(error_str)

    if not error:
        return True
    return None
def check_file_ending(file: str,
                      pattern: str = "gtf") -> None:
    """
    Check if a file has a certain file ending.

    TODO do we need this?

    Parameters
    ----------
    file : str
        Path to the file.
    pattern : str, default 'gtf'
        File ending to be checked for. The check passes if `file` ends with
        `pattern` literally, or if `pattern` is a valid regex that matches `file`
        from its beginning (``re.match`` semantics; append ``$`` to anchor the end).

    Raises
    ------
    ValueError
        If file does not have the expected file ending.
    """
    # Literal suffix check first: most plain endings (e.g. "gtf") are also valid
    # regexes, but re.match anchors at the start of the path and would reject
    # "genes.gtf".
    if file.endswith(pattern):
        return

    try:
        matched = re.match(pattern, file) is not None
    except re.error:
        # pattern is not a valid regex; only the literal suffix check applies
        matched = False

    if not matched:
        raise ValueError(f"File '{file}' does not have the expected file ending '{pattern}'")
def is_regex(regex: str) -> bool:
    """
    Check if a string is a valid regex.

    Parameters
    ----------
    regex : str
        String to be checked.

    Returns
    -------
    bool
        True if ``re.compile`` accepts the string, False if it raises ``re.error``.
    """
    try:
        re.compile(regex)
    except re.error:
        return False
    else:
        return True
def check_marker_lists(adata: sc.AnnData,
                       marker_dict: dict[str, list[str]]) -> dict[str, list[str]]:
    """
    Remove genes in custom marker genes lists which are not present in dataset.

    Marker lists without any gene present in the data are dropped from the result.
    The input dictionary is not modified; a new dictionary is returned.

    Parameters
    ----------
    adata : sc.AnnData
        The anndata object containing features to annotate.
    marker_dict : dict[str, list[str]]
        A dictionary containing a list of marker genes as values and corresponding cell types as keys.
        The marker genes given in the lists need to match the index of adata.var.

    Returns
    -------
    dict[str, list[str]]
        A dictionary containing a list of marker genes as values and corresponding cell types as keys.
        Lists with missing genes are reduced to the found genes, in their original order
        without duplicates; fully present lists are kept unchanged.
    """
    var_names = set(adata.var.index)  # built once instead of once per marker list
    checked = {}
    for key, genes in marker_dict.items():
        unique_genes = list(dict.fromkeys(genes))  # drop duplicates, keep input order
        found_in_var = [gene for gene in unique_genes if gene in var_names]
        not_found_in_var = [gene for gene in unique_genes if gene not in var_names]

        if not found_in_var:
            # the list is left out of the result, as the warning announces
            logger.warning(f"No marker in {key} marker list can be found in the data. "
                           + "Please check your marker list. Removing empty marker list form dictionary.")
        elif not_found_in_var:
            checked[key] = found_in_var
            logger.info(f"Removed {not_found_in_var} from {key} marker gene list")
        else:
            checked[key] = genes
    return checked
def check_type(obj: Any, obj_name: str, test_type: Any):
    """
    Check type of given object.

    Parameters
    ----------
    obj : Any
        Object for which the type should be checked.
    obj_name : str
        Object name that would be shown in the error message.
    test_type : Any
        Type (or tuple of types) that obj is tested for with ``isinstance``.

    Raises
    ------
    TypeError
        If object type does not match test type.

    Notes
    -----
    This function is mostly replaced by beartype.
    Only used for types not supported by beartype.
    """
    if isinstance(obj, test_type):
        return
    raise TypeError(f"Paramter {obj_name} is required to be of type: "
                    + f"{test_type}, but is type: {type(obj)}")