# Source code for mito.pl.diagnostic_plots

"""
Utils and plotting functions to visualize and inspect SNVs from a MAESTER 
experiment and maegatk/mito_preprocessing output.
"""

import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib
import matplotlib.pyplot as plt
import plotting_utils as plu
from typing import Dict, Iterable, Tuple, Any
from anndata import AnnData
from matplotlib.ticker import FixedLocator, FuncFormatter
from ..ut.utils import load_mut_spectrum_ref
from ..ut.positions import MAESTER_genes_positions
from ..pp.filters import mask_mt_sites
from ..pp.preprocessing import annotate_vars


## 


def vars_AF_spectrum(
    afm: AnnData,
    ax: matplotlib.axes.Axes = None,
    color: str = 'b',
    **kwargs
) -> matplotlib.axes.Axes:
    """
    Ranked AF distributions (as in Miller et al., 2022).

    Draws one line per column of ``afm.X``: the column values sorted in
    ascending order, plotted against their rank.

    Parameters
    ----------
    afm
        AnnData whose ``X`` matrix is plotted (presumably cells x variants
        allelic frequencies -- TODO confirm). ``X`` may be sparse or dense.
    ax
        Axes to draw on. If None, the current Axes (``plt.gca()``) is used.
    color
        Color shared by all the lines.
    **kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    The Axes that was drawn on.
    """
    # The signature advertises ax=None, so fall back to the current Axes
    # instead of failing on ``None.plot``.
    if ax is None:
        ax = plt.gca()

    # Sparse matrices expose ``toarray``; dense arrays are used as they are.
    X = afm.X.toarray() if hasattr(afm.X, 'toarray') else np.asarray(afm.X)

    # Sort every column independently in one vectorised call.
    ranked = np.sort(X, axis=0)
    for column in ranked.T:
        ax.plot(column, '-', color=color, **kwargs)
    plu.format_ax(ax=ax, xlabel='Cells (ranked)', ylabel='Allelic Frequency')

    return ax
##
def plot_ncells_nAD(
    afm: AnnData,
    ax: matplotlib.axes.Axes = None,
    title: str = None,
    xticks: Iterable[Any] = None,
    yticks: Iterable[Any] = None,
    s: float = 5,
    color: Any = 'k',
    alpha: float = .7,
    **kwargs
) -> matplotlib.axes.Axes:
    """
    Plots similar to the one in Weng et al., 2024, followed by the two
    commentaries from Lareau and Weng, 2024.

    For each variant, plot the n of positive cells (x-axis) vs the mean
    number of AD in positive cells (y-axis). Both axes are log2-scaled.

    Parameters
    ----------
    afm
        AnnData to plot. NOTE: ``annotate_vars(afm, overwrite=True)`` is
        called on it, so ``afm`` is modified in place; the plotted values are
        then read from ``afm.var['Variant_CellN']`` and
        ``afm.var['mean_AD_in_positives']``.
    ax
        Axes to draw on. If None, the current Axes (``plt.gca()``) is used.
    title
        Axes title. None gives an empty title.
    xticks, yticks
        Major tick locations for the x and y axis. None selects the built-in
        defaults defined in the body.
    s
        Marker size (passed as ``markersize``).
    color
        Marker color.
    alpha
        Marker opacity.
    **kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    The Axes that was drawn on.
    """
    # The signature advertises ax=None, so fall back to the current Axes
    # instead of failing on ``None.plot``.
    if ax is None:
        ax = plt.gca()

    # Presumably (re)computes the two .var columns read below -- TODO confirm
    # against annotate_vars.
    annotate_vars(afm, overwrite=True)
    ax.plot(
        afm.var['Variant_CellN'], afm.var['mean_AD_in_positives'],
        'o', c=color, markersize=s, alpha=alpha, **kwargs
    )
    ax.set_yscale('log', base=2)
    ax.set_xscale('log', base=2)

    if xticks is None:
        xticks = [0,1,2,5,10,20,40,80,160,320,640]
    if yticks is None:
        # NOTE(review): 132 and 264 break the power-of-two progression
        # (128, 256?) -- kept as in the original, confirm intent.
        yticks = [0,1,2,4,8,16,32,64,132,264]
    ax.xaxis.set_major_locator(FixedLocator(xticks))
    ax.yaxis.set_major_locator(FixedLocator(yticks))

    def integer_formatter(val, pos):
        # Show plain integers rather than the default 2^k log-axis labels.
        return f'{int(val)}'

    # One formatter instance per axis: formatters are bound to their axis.
    ax.xaxis.set_major_formatter(FuncFormatter(integer_formatter))
    ax.yaxis.set_major_formatter(FuncFormatter(integer_formatter))
    ax.set(
        xlabel='n +cells',
        ylabel='n ALT UMI / +cell',
        title='' if title is None else title
    )

    return ax
##
def mut_profile(
    mut_list: Iterable[str],
    figsize: Tuple[float,float] = (6,3),
    legend_kwargs: Dict[str,Any] = None
) -> matplotlib.figure.Figure:
    """
    Re-implementation of MutationProfile_bulk (from Weng et al., 2024).

    The reference table returned by ``load_mut_spectrum_ref()`` is flagged
    with the variants in ``mut_list``, then grouped by ``three_plot``,
    ``group_change`` and ``strand``. One panel is drawn per ``group_change``;
    within a panel, the bars show the number of called variants (``n_obs``)
    per ``three_plot`` entry, colored by strand.

    Parameters
    ----------
    mut_list
        Called variants. Underscores are stripped from each entry before it
        is matched against the ``variant`` column of the reference table.
    figsize
        Figure size, passed to ``plt.subplots``.
    legend_kwargs
        Extra keyword arguments forwarded to ``plu.add_legend``. None means
        no extra arguments.

    Returns
    -------
    The newly created Figure.
    """
    # None instead of a shared mutable default dict.
    legend_kwargs = {} if legend_kwargs is None else legend_kwargs

    ref_df = load_mut_spectrum_ref()
    called_variants = [ x.replace('_', '') for x in mut_list ]
    ref_df['called'] = ref_df['variant'].isin(called_variants)
    total = len(ref_df)
    total_called = ref_df['called'].sum()

    grouped = ref_df.groupby(['three_plot', 'group_change', 'strand'])
    prop_df = grouped.agg(
        observed_prop_called=('called', lambda x: x.sum() / total_called),
        expected_prop=('variant', lambda x: x.count() / total),
        n_obs=('called', 'sum'),
        n_total=('variant', 'count')
    ).reset_index()
    # NOTE(review): only 'n_obs' is plotted below, while the y-axis is
    # labelled 'Substitution rate'; the proportions and fold-change are
    # computed but never displayed -- confirm which quantity is intended.
    prop_df['fc_called'] = prop_df['observed_prop_called'] / prop_df['expected_prop']
    prop_df = prop_df.set_index('three_plot')
    # Insert '>' between the characters of each group_change label.
    prop_df['group_change'] = prop_df['group_change'].map(lambda s: '>'.join(s))

    changes = prop_df['group_change'].unique()
    # squeeze=False always yields a 2D array of Axes, so a single panel works
    # too. constrained_layout is not requested: the tight_layout() call below
    # replaces it anyway, and mixing the two only triggers a warning.
    fig, axs = plt.subplots(
        1, len(changes), figsize=figsize, sharey=True,
        gridspec_kw={'wspace': 0.1}, squeeze=False
    )
    axs = axs.ravel()

    strand_palette = {'H': '#05A8B3', 'L': '#D76706'}
    for i, change in enumerate(changes):
        ax = axs[i]
        df_ = prop_df[prop_df['group_change'] == change]
        for strand in df_['strand'].unique():
            plu.bar(
                df_[df_['strand'] == strand].reset_index(),
                x='three_plot', y='n_obs',
                color=strand_palette[strand],
                categorical_cmap=None,
                width=1,
                alpha=.5,
                edgecolor=None,
                with_label=False,
                ax=ax
            )
        plu.format_ax(
            ax,
            xticks=[],
            xlabel=change,
            # Panels share the y-axis: label only the first one.
            ylabel='Substitution rate' if i==0 else '',
            title=f'n: {df_["n_obs"].sum()}'
        )

    plu.add_legend(
        ax=axs[0],
        colors=strand_palette,
        ncols=1,
        loc='upper left',
        bbox_to_anchor=(0,1),
        label='Strand',
        **legend_kwargs
    )
    fig.tight_layout()

    return fig
##
def MT_coverage_polar(
    cov: pd.DataFrame,
    var_subset: Iterable[str] = None,
    ax: matplotlib.axes.Axes = None,
    n_xticks: int = 6,
    xticks_size: float = 7,
    yticks_size: float = 2,
    xlabel_size: float = 6,
    ylabel_size: float = 9,
    kwargs_main: Dict[str,Any] = None,
    kwargs_subset: Dict[str,Any] = None
) -> matplotlib.axes.Axes:
    """
    Plot coverage and muts across MT-genome positions.

    The column-wise mean of ``cov`` is drawn, log10-transformed, around a
    polar axis. Positions belonging to ``var_subset`` can be highlighted on
    top of the main line.

    Parameters
    ----------
    cov
        Coverage table; columns are averaged with ``cov.mean(axis=0)``
        (presumably cells x positions -- TODO confirm).
    var_subset
        Variants to highlight. The integer before the first underscore of
        each entry is matched against the column labels of ``cov``.
    ax
        Polar Axes to draw on. If None, a new figure with one polar Axes is
        created.
    n_xticks
        Number of evenly spaced positions computed between 1 and the number
        of columns; all but the last one are shown as angular tick labels.
    xticks_size
        Font size of the angular tick labels.
    yticks_size
        Font size of the radial tick labels.
    xlabel_size
        Font size of the 'n UMIs' text.
    ylabel_size
        Font size of the 'Position (bp)' text.
    kwargs_main
        Overrides for the default style of the main coverage line.
    kwargs_subset
        Overrides for the default style of the highlighted positions.

    Returns
    -------
    The Axes that was drawn on.
    """
    # theta/r calls below need a polar projection, so build one when no Axes
    # is supplied instead of failing on ``None.plot``.
    if ax is None:
        _, ax = plt.subplots(subplot_kw={'projection': 'polar'})

    # Defaults first, user overrides on top. None replaces the former shared
    # mutable default dicts.
    kwargs_main_ = {'c':'#494444', 'linestyle':'-', 'linewidth':.7}
    kwargs_subset_ = {'c':'r', 'marker':'+', 'markersize':10, 'linestyle':''}
    kwargs_main_.update(kwargs_main or {})
    kwargs_subset_.update(kwargs_subset or {})

    x = cov.mean(axis=0)
    log_cov = np.log10(x)
    theta = np.linspace(0, 2*np.pi, len(x))
    ax.plot(theta, log_cov, **kwargs_main_)

    if var_subset is not None:
        # A plain comprehension accepts any iterable of strings, not only
        # pandas objects exposing ``.map``.
        var_pos = [ int(v.split('_')[0]) for v in var_subset ]
        test = x.index.isin(var_pos)
        ax.plot(theta[test], log_cov[test], **kwargs_subset_)

    # No fixed-length truncation here: the labels must follow n_xticks so
    # that they always match the n_xticks-1 tick positions set below.
    ticks = [ int(round(p)) for p in np.linspace(1, cov.shape[1], n_xticks) ]
    ax.set_theta_offset(np.pi/2)
    ax.set_xticks(np.linspace(0, 2*np.pi, n_xticks-1, endpoint=False))
    ax.set_xticklabels(ticks[:-1], fontsize=xticks_size)

    # Default radial labels are replaced by hand-placed text at theta=0.
    ax.set_yticklabels([])
    for tick in np.arange(-1,4,1):
        ax.text(0, tick, str(tick), ha='center', va='center', fontsize=yticks_size)

    ax.text(0, 1.5, 'n UMIs', ha='center', va='center', fontsize=xlabel_size, color='black')
    ax.text(np.pi, 4, 'Position (bp)', ha='center', va='center', fontsize=ylabel_size, color='black')
    ax.spines['polar'].set_visible(False)

    return ax
##
def MT_coverage_by_gene_polar(
    cov: pd.DataFrame,
    sample: str = None,
    subset: Iterable[str] = None,
    ax: matplotlib.axes.Axes = None
) -> matplotlib.axes.Axes:
    """
    Plot coverage and muts across MT-genome positions, with annotated genes.

    ``cov`` is pivoted to a cell x position table (missing entries filled
    with 0). The log10 of the per-position mean is drawn in grey around a
    polar axis, and the stretch covered by each gene listed in
    ``MAESTER_genes_positions`` is over-plotted in its own color.

    Parameters
    ----------
    cov
        Long-format coverage table with 'cell', 'pos' and 'n' fields. The
        caller's object is left untouched.
    sample
        Sample name shown on the first line of the title. If None, only the
        coverage summary is shown.
    subset
        Cells to keep. None keeps every cell.
    ax
        Polar Axes to draw on. If None, a new figure with one polar Axes is
        created.

    Returns
    -------
    The Axes that was drawn on.
    """
    # theta/r calls below need a polar projection, so build one when no Axes
    # is supplied instead of failing on ``None.plot``.
    if ax is None:
        _, ax = plt.subplots(subplot_kw={'projection': 'polar'})

    if subset is not None:
        cov = cov.query('cell in @subset')

    # 16569 positions: presumably the human MT-genome length -- TODO confirm.
    # The categorical makes every position a column of the pivot, covered or
    # not. assign() returns a new frame, so the caller's DataFrame is never
    # modified (and no SettingWithCopyWarning is raised after query()).
    n_positions = 16569
    cov = cov.assign(
        pos=pd.Categorical(cov['pos'], categories=range(1, n_positions+1))
    )
    cov = cov.pivot_table(
        index='cell', columns='pos', values='n', dropna=False, fill_value=0
    )

    df_mt = (
        pd.DataFrame(
            MAESTER_genes_positions,
            columns=['gene', 'start', 'end']
        )
        .set_index('gene')
        .sort_values('start')
    )

    x = cov.mean(axis=0)
    log_cov = np.log10(x)

    # Compute the site mask once and reuse it for both summaries.
    is_target = mask_mt_sites(cov.columns)
    median_target = cov.loc[:, is_target].median(axis=0).median()
    median_untarget = cov.loc[:, ~is_target].median(axis=0).median()

    theta = np.linspace(0, 2*np.pi, cov.shape[1])
    ax.plot(theta, log_cov, '-', linewidth=.7, color='grey')

    # One palette color per gene, in order of start position.
    idx = np.arange(1, x.size+1)
    for gene, gene_color in zip(df_mt.index, sc.pl.palettes.default_102):
        start, stop = df_mt.loc[gene, ['start', 'end']].values
        test = (idx>=start) & (idx<=stop)
        ax.plot(theta[test], log_cov[test], color=gene_color, linewidth=1.5)

    # 8 evenly spaced positions, the last one dropped: it lands on the same
    # angle as the first.
    ticks = [ int(round(p)) for p in np.linspace(1, cov.shape[1], 8) ][:7]
    ax.set_theta_offset(np.pi/2)
    ax.set_xticks(np.linspace(0, 2*np.pi, 7, endpoint=False))
    ax.set_xticklabels(ticks)
    ax.xaxis.set_tick_params(labelsize=7)
    ax.yaxis.set_tick_params(labelsize=7)
    ax.set_rlabel_position(0)

    # Avoid rendering the literal 'None' when no sample name is given.
    summary = f'Target: {median_target:.2f}, untarget: {median_untarget:.2f}'
    title = summary if sample is None else f'{sample}\n{summary}'
    ax.set(xlabel='Position (bp)', title=title)

    return ax
##