""""
Miscellaneous utilities.
"""
import os
import sys
import time
import pickle
from shutil import rmtree
import logging
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
##
_cell_filters = ['filter1', 'filter2']
_var_filters = [
'baseline',
'CV',
'miller2022',
'weng2024',
'MQuad',
'MiTo',
'GT_enriched'
# DEPRECATED
# 'ludwig2019',
# 'velten2021',
# 'seurat',
# 'MQuad_optimized',
# 'density',
# 'GT_stringent'
]
# Try to find assets directory in multiple locations
def _find_assets_path():
# First try relative path for development
dev_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../assets')
if os.path.exists(dev_path):
return dev_path
# Try conda environment assets directory
import sys
if hasattr(sys, 'prefix'):
conda_assets = os.path.join(sys.prefix, 'assets')
if os.path.exists(conda_assets):
return conda_assets
# Try installed package location in site-packages
import site
for site_dir in site.getsitepackages():
assets_path = os.path.join(site_dir, 'assets')
if os.path.exists(assets_path):
return assets_path
# Fallback to user site directory
user_assets = os.path.join(site.getusersitepackages(), 'assets')
if os.path.exists(user_assets):
return user_assets
# If nothing found, return the development path anyway
return dev_path
path_assets = _find_assets_path()
##
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s' # Custom format
)
##
class TimerError(Exception):
"""
A custom exception used to report errors in use of Timer class.
"""
class Timer:
"""
A custom Timer class.
"""
def __init__(self):
self._start_time = None
def start(self):
"""
Start a new timer.
"""
if self._start_time is not None:
raise TimerError(f"Timer is running. Use .stop() to stop it")
self._start_time = time.perf_counter()
def stop(self, pretty=True):
"""
Stop the timer, and report the elapsed time.
"""
if self._start_time is None:
raise TimerError(f"Timer is not running. Use .start() to start it")
elapsed_time = time.perf_counter() - self._start_time
self._start_time = None
if pretty:
if elapsed_time > 100:
unit = 'min'
elapsed_time = elapsed_time / 60
elif elapsed_time > 1000:
unit = 'h'
elapsed_time = elapsed_time / 3600
else:
unit = 's'
formatted_time = f'{round(elapsed_time, 2)} {unit}'
else:
formatted_time = round(elapsed_time, 2)
return formatted_time
##
def make_folder(path, name, overwrite=True):
"""
A function to create a new {name} folder at the {path} path.
"""
os.chdir(path)
if not os.path.exists(name) or overwrite:
rmtree(os.path.join(path, name), ignore_errors=True)
os.makedirs(name)
else:
pass
##
def update_params(d_original, d_passed):
for k in d_passed:
if k in d_original:
pass
else:
print(f'{k}:{d_passed[k]} kwargs added...')
d_original[k] = d_passed[k]
return d_original
##
def one_hot_from_labels(y):
"""
My one_hot encoder from a categorical variable.
"""
if len(y.categories) > 2:
Y = np.concatenate(
[ np.where(y == x, 1, 0)[:, np.newaxis] for x in y.categories ],
axis=1
)
else:
Y = np.where(y == y.categories[0], 1, 0)
return Y
##
def rescale(x):
"""
Max/min rescaling.
"""
if np.min(x) != np.max(x):
return (x - np.min(x)) / (np.max(x) - np.min(x))
else:
return x
##
def ji(x, y):
"""
Jaccard Index between two list-like objs.
"""
x = set(x)
y = set(y)
ji = len(x&y) / len(x|y)
return ji
##
def flatten_dict(d):
result = {}
for key, value in d.items():
if isinstance(value, dict):
result.update(flatten_dict(value))
else:
result[key] = value
return result
##
def format_tuning(path_tuning):
"""
Format tuning dataframe.
"""
assert os.path.exists(path_tuning)
options = pd.read_csv(os.path.join(path_tuning, 'all_options_final.csv'))
metrics = pd.read_csv(os.path.join(path_tuning, 'all_metrics_final.csv'))
df = pd.merge(
options.pivot(index=['sample', 'job_id'], values='value', columns='option').reset_index(),
metrics.pivot(index=['sample', 'job_id'], values='value', columns='metric').reset_index(),
on=['sample', 'job_id']
)
options = options['option'].unique().tolist()
metrics = metrics['metric'].unique().tolist()
return df, metrics, options
##
def extract_kwargs(args, only_tree=False):
"""
Extract preprocessing parameters from CLI and tuning information.
"""
path_tuning = args.path_tuning if hasattr(args, 'path_tuning') else None
if path_tuning is not None and args.job_id is not None:
path_options = os.path.join(path_tuning, 'all_options_final.csv')
if os.path.exists(path_options):
df_options = pd.read_csv(path_options).loc[lambda x: x['job_id'] == args.job_id]
d = { k:v for k,v in zip(df_options['option'],df_options['value']) }
if not only_tree:
cell_filter = d['cell_filter']
min_cell_number = int(d['min_cell_number'])
lineage_column = d['lineage_column']
filtering = d['filtering']
bin_method = d['bin_method']
metric = d['metric']
min_n_var = int(d['min_n_var'])
filter_dbs = d['filter_dbs']
filter_moran = d['filter_moran']
kwargs = {
'min_cell_number' : min_cell_number,
'lineage_column' : lineage_column,
'filtering' : filtering,
'bin_method' : bin_method,
'min_n_var' : min_n_var,
'ncores' : args.ncores,
'metric' : metric,
'spatial_metrics' : args.spatial_metrics,
'filter_moran' : filter_moran
}
filtering_kwargs = {
'min_cov' : int(d['min_cov']),
'min_var_quality': int(d['min_var_quality']),
'min_frac_negative' : float(d['min_frac_negative']),
'min_n_positive' : int(d['min_n_positive']),
'af_confident_detection' : float(d['af_confident_detection']),
'min_n_confidently_detected' : int(d['min_n_confidently_detected']),
'min_mean_AD_in_positives' : float(d['min_mean_AD_in_positives']),
'min_mean_DP_in_positives' : float(d['min_mean_DP_in_positives'])
}
filtering_kwargs = filtering_kwargs if kwargs['filtering'] == 'MiTo' else {}
binarization_kwargs = {
't_prob' : float(d['t_prob']),
't_vanilla' : float(d['t_vanilla']),
'min_AD' : int(d['min_AD']),
'min_cell_prevalence' : float(d['min_cell_prevalence']),
'k' : int(d['k']),
'gamma' : float(d['gamma']),
'resample' : False
}
tree_kwargs = {'solver':d['solver'], 'metric':d['metric']}
else:
cell_filter = None; kwargs = None;
filtering_kwargs = None; binarization_kwargs = None
tree_kwargs = {'solver':d['solver'], 'metric':d['metric']}
else:
raise ValueError(f'{path_options} does not exists!')
else:
if not only_tree:
cell_filter = args.cell_filter
kwargs = {
'min_cell_number' : args.min_cell_number,
'lineage_column' : args.lineage_column,
'filtering' : args.filtering if args.filtering in _var_filters else None,
'bin_method' : args.bin_method,
'min_n_var' : args.min_n_var,
'filter_dbs' : True if args.filter_dbs == 'true' else False,
'ncores' : args.ncores,
'metric' : args.metric,
'spatial_metrics' : True if args.spatial_metrics == 'true' else False,
'filter_moran' : True if args.filter_moran == 'true' else False,
}
filtering_kwargs = {
'min_cov' : args.min_cov,
'min_var_quality': args.min_var_quality,
'min_frac_negative' : args.min_frac_negative,
'min_n_positive' : args.min_n_positive,
'af_confident_detection' : args.af_confident_detection,
'min_n_confidently_detected' : args.min_n_confidently_detected,
'min_mean_AD_in_positives' : args.min_mean_AD_in_positives,
'min_mean_DP_in_positives' : args.min_mean_DP_in_positives
}
filtering_kwargs = filtering_kwargs if kwargs['filtering'] == 'MiTo' else {}
binarization_kwargs = {
't_prob' : args.t_prob,
't_vanilla' : args.t_vanilla,
'min_AD' : args.min_AD,
'min_cell_prevalence' : args.min_cell_prevalence,
'k' : args.k,
'gamma' : args.gamma,
'resample' : False
}
tree_kwargs = {'solver':args.solver, 'metric':args.metric}
else:
cell_filter = None; kwargs = None;
filtering_kwargs = None; binarization_kwargs = None
tree_kwargs = {'solver':args.solver, 'metric':args.metric}
return cell_filter, kwargs, filtering_kwargs, binarization_kwargs, tree_kwargs
##
def rank_items(df, groupings, metrics, weights, metric_annot):
df_agg = df.groupby(groupings, dropna=False)[metrics].mean().reset_index()
for metric_type in metric_annot:
colnames = []
for metric in metric_annot[metric_type]:
colnames.append(f'{metric}_rescaled')
if metric in ['n_dbSNP', 'n_REDIdb']:
df_agg[metric] = -df_agg[metric]
df_agg[f'{metric}_rescaled'] = (df_agg[metric] - df_agg[metric].min()) / \
(df_agg[metric].max() - df_agg[metric].min())
x = df_agg[colnames].mean(axis=1)
df_agg[f'{metric_type} score'] = (x - x.min()) / (x.max() - x.min())
x = np.sum(df_agg[ [ f'{k} score' for k in metric_annot ] ] * np.array([ weights[k] for k in metric_annot ]), axis=1)
df_agg['Overall score'] = (x - x.min()) / (x.max() - x.min())
df_agg = df_agg.sort_values('Overall score', ascending=False)
return df_agg
##
def load_mut_spectrum_ref():
df = pd.read_csv(os.path.join(path_assets, 'weng2024_mut_spectrum_ref.csv'), index_col=0)
return df
##
def load_mt_gene_annot():
df = pd.read_csv(os.path.join(path_assets, 'formatted_table_wobble.csv'), index_col=0)
df['mut'] = df['Position'].astype(str) + '_' + df['Reference'] + '>' + df['Variant']
return df
##
def load_common_dbSNP():
common = pd.read_csv(os.path.join(path_assets, 'dbSNP_MT.txt'), index_col=0, sep='\t')
common = common['pos'].astype('str') + '_' + common['REF'] + '>' + common['ALT'].map(lambda x: x.split('|')[0])
common = common.to_list()
return common
##
def load_edits_REDIdb():
edits = pd.read_csv(os.path.join(path_assets, 'REDIdb_MT.txt'), index_col=0, sep='\t')
edits = edits.query('nSamples>100')
edits = edits['Position'].astype('str') + '_' + edits['Ref'] + '>' + edits['Ed']
edits = edits.to_list()
return edits
##
[docs]
def subsample_afm(afm, n_clones=3, ncells=100, freqs=np.array([.3,.3,.4])):
assert 1-np.array(freqs).sum() <= .05
assert len(freqs) == n_clones
clones_sorted = afm.obs['GBC'].value_counts().index
clones = clones_sorted[:n_clones].to_list()
cells = []
for clone, f in zip(clones, freqs):
afm_clone = afm[afm.obs.query('GBC==@clone').index,:].copy()
afm_clone = afm_clone[(afm_clone.layers['bin']>0).sum(axis=1).flatten()>2,
(afm_clone.layers['bin']>0).sum(axis=0).flatten()>=2]
n_cells_clone = min(round(ncells*f), afm_clone.shape[0])
cells.extend(
np.random.choice(afm_clone.obs_names, n_cells_clone, replace=False).tolist()
)
afm_subsample = afm[cells].copy()
return afm_subsample
##
def select_jobs(df, sample, n_cells, n_GBC_groups, frac_unassigned):
"""
Select jobs, and choose one for clonal inference benchmarking
"""
df_selected = (
df.loc[
(df['sample'] == sample) & \
(df['n_cells'] >= n_cells) & \
(df['n_GBC_groups'] >= n_GBC_groups) & \
(df['frac_unassigned'] <= frac_unassigned)
]
)
df_selected = (
df_selected[[
'job_id', 'pp_method', 'bin_method', 'af_confident_detection', 'min_cell_number', 'metric',
'ARI', 'corr', 'NMI', 'AUPRC', 'n_cells', 'unassigned', 'n_vars', 'n_GBC_groups', 'n MiTo clone',
]]
)
df_final = df_selected.sort_values('ARI', ascending=False).head(5)
return df_selected, df_final
##
def extract_bench_df(path):
L = []
for folder,_,files in os.walk(path):
for file in files:
if file.endswith('pickle'):
with open(os.path.join(folder, file), 'rb') as f:
d = pickle.load(f)
d['n_inferred'] = d['labels'].loc[lambda x: ~x.isna()].unique().size
del d['labels']
L.append(d)
df_bench = pd.DataFrame(L)
return df_bench
def perturb_AD_counts(a, perc_sites=.75, theta=1, add=True):
"""
Perturb AD and .X layers of afm.
"""
afm = a.copy()
AD_new = afm.layers['AD'].copy()
n_vars = AD_new.shape[1]
n_sites = int(np.round(n_vars * perc_sites))
idx = np.random.choice(np.arange(n_vars), n_sites)
for i in idx:
ad = afm.layers['AD'][:,i].toarray().flatten()
dp = afm.layers['site_coverage'][:,i].toarray().flatten()
p_fit = np.sum(ad) / np.sum(dp)
p_noise = theta * p_fit
if add:
new_ad = ad + (dp * p_noise)
else:
new_ad = ad - (dp * p_noise)
AD_new[:,i] = new_ad
corr = np.corrcoef(afm.layers['AD'].toarray().flatten(), AD_new.toarray().flatten())[0,1]
afm.layers['AD'] = csr_matrix(AD_new)
afm.X = csr_matrix(AD_new / (afm.layers['DP'].toarray() + .000001))
return afm, corr
##