"""
Tree plotting utils.
"""
import logging
import pandas as pd
import scanpy as sc
import plotting_utils as plu
from typing import Iterable, Dict, Any
from cassiopeia.data import CassiopeiaTree
from cassiopeia.plotting.local import utilities as ut
from cassiopeia.plotting.local import *
from matplotlib.patches import Polygon
from .other_plots import *
##
# Fallback palettes for categorical feature colorstrips; _place_tree_and_annotations
# picks one by the number of categorical colorstrips it has placed so far.
_categorical_cmaps = [
    sc.pl.palettes.vega_20_scanpy,
    sc.pl.palettes.default_20,
    plu.ten_godisnot,
    'set1',
    'dark',
]
# Continuous colormap names; the first entry is the fallback for numeric features.
_continuous_cmaps = ['viridis', 'inferno', 'magma']
# Fallback colormap for numeric character colorstrips.
_cont_character_cmap = 'mako'
# Fallback state -> color mapping for categorical character colorstrips.
_bin_character_cmap = {
    1: 'r',
    0: 'b',
    -1: 'lightgrey',
    np.nan: 'lightgrey',
}
##
def _to_polar_coords(d):
    """
    Run every coordinate pair of a mapping through ut.polars_to_cartesians.

    Despite the function name, the conversion applied is the one performed by
    ``ut.polars_to_cartesians``.

    Parameters
    ----------
    d : dict
        Mapping key -> (first, second). Each component is either a scalar
        (one point) or a list (several points).

    Returns
    -------
    dict
        Mapping with the same keys and converted 2-tuples. Scalar inputs give
        scalar outputs; list inputs give whatever sequences the converter returns.
    """
    converted = {}
    for key, pair in d.items():
        first, second = pair
        if isinstance(first, list):
            xs, ys = ut.polars_to_cartesians(first, second)
            converted[key] = (xs, ys)
            continue
        # The converter is called on sequences: wrap the scalars, then unwrap.
        xs, ys = ut.polars_to_cartesians([first], [second])
        converted[key] = (xs[0], ys[0])
    return converted
##
def _to_polar_colorstrips(L):
    """
    Run the box coordinates of every colorstrip through ut.polars_to_cartesians.

    Parameters
    ----------
    L : list of dict
        Colorstrips. Each maps a key to a 4-tuple whose first two items are
        coordinates; the last two items are passed through untouched.

    Returns
    -------
    list of dict
        New colorstrips, in the same order, with converted coordinates.
    """
    def _convert_strip(strip):
        converted = {}
        for key, (xs, ys, third, fourth) in strip.items():
            new_xs, new_ys = ut.polars_to_cartesians(xs, ys)
            converted[key] = (new_xs, new_ys, third, fourth)
        return converted

    return [_convert_strip(strip) for strip in L]
##
def _place_tree_and_annotations(
    tree,
    features=None,
    characters=None,
    orient=90,
    extend_branches=True,
    angled_branches=True,
    add_root=True,
    continuous_cmaps=None,
    cont_character_cmap=None,
    categorical_cmaps=None,
    bin_character_cmap=None,
    layer='raw',
    colorstrip_width=None,
    colorstrip_spacing=None,
    vmin=None,
    vmax=None,
    vmin_characters=None,
    vmax_characters=None
):
    """
    Util to set tree elements: node/branch coordinates and leaf colorstrips.

    Parameters
    ----------
    tree : CassiopeiaTree
        Tree to place.
    features : iterable of str, optional
        Columns of tree.cell_meta to draw as colorstrips.
    characters : iterable of str, optional
        Columns of tree.layers[layer] to draw as colorstrips.
    orient : float, int or str
        A number selects the "polar" layout; any other value is forwarded as is.
    extend_branches, angled_branches, add_root : bool
        Forwarded to ut.place_tree.
    continuous_cmaps : dict, optional
        feature -> colormap for numeric features. If given, it must contain
        every numeric feature.
    cont_character_cmap : str, optional
        Colormap for numeric characters.
    categorical_cmaps : dict, optional
        feature -> palette (str, list, or category -> color dict).
    bin_character_cmap : dict, optional
        state -> color mapping for categorical characters.
    layer : str
        Key of tree.layers holding the character values.
    colorstrip_width, colorstrip_spacing : float, optional
        Colorstrip geometry; falsy values fall back to sizes computed from the tree.
    vmin, vmax : float, optional
        Color limits for numeric features (10th/90th percentile when None).
    vmin_characters, vmax_characters : float, optional
        Color limits for numeric characters.

    Returns
    -------
    node_coords : dict
    branch_coords : dict
    colorstrips : list of (dict, str)
        One (colorstrip, covariate name) pair per requested covariate.

    Raises
    ------
    KeyError
        If a feature/character is missing, or a numeric feature is absent from
        a user-supplied continuous_cmaps.
    ValueError
        If a covariate has an unsupported dtype or a categorical cmap has an
        unsupported type.
    """
    is_polar = isinstance(orient, (float, int))
    loc = "polar" if is_polar else orient

    # Node and branch coords
    node_coords, branch_coords = ut.place_tree(
        tree,
        orient=orient,
        extend_branches=extend_branches,
        angled_branches=angled_branches,
        add_root=add_root
    )

    # Colorstrips
    anchor_coords = {k: node_coords[k] for k in node_coords if tree.is_leaf(k)}
    tight_width, tight_height = compute_colorstrip_size(node_coords, anchor_coords, loc)
    width = colorstrip_width or tight_width
    spacing = colorstrip_spacing or tight_width / 2
    colorstrips = []

    # Materialise as lists so any iterable (tuple, pandas Index, ...) can be concatenated.
    features = list(features) if features is not None else []
    characters = list(characters) if characters is not None else []
    covariates = features + characters

    # Computed lazily, once, and only if characters are requested: a
    # features-only call must not require tree.layers[layer] to exist.
    is_bin_layer = None
    n_cat = 0

    # Here we go
    for cov in covariates:

        # Feature
        if cov in features:
            if cov not in tree.cell_meta.columns:
                raise KeyError(f'{cov} not in tree.cell_meta!')
            x = tree.cell_meta[cov].copy()

        # Character (takes precedence over a feature with the same name)
        if cov in characters:
            if is_bin_layer is None:
                is_bin_layer = all(
                    v in [1, 0, -1] for v in tree.layers[layer].iloc[:, 0].unique()
                )
            if cov not in tree.layers[layer].columns:
                raise KeyError(f'{cov} not in tree.layers[{layer}].')
            x = tree.layers[layer][cov].copy()
            if is_bin_layer:
                # Discrete states are drawn with the categorical machinery.
                x = x.astype('category')

        # Colorstrip specification
        if pd.api.types.is_numeric_dtype(x):
            if cov in features:
                vmin_annot = np.percentile(x, 10) if vmin is None else vmin
                vmax_annot = np.percentile(x, 90) if vmax is None else vmax
                if continuous_cmaps is None:
                    continuous_cmap = _continuous_cmaps[0]
                elif cov in continuous_cmaps:
                    continuous_cmap = continuous_cmaps[cov]
                else:
                    raise KeyError(f'{cov} not in continuous_cmaps.')
            elif cov in characters:
                vmin_annot = vmin_characters
                vmax_annot = vmax_characters
                continuous_cmap = (
                    cont_character_cmap if cont_character_cmap is not None
                    else _cont_character_cmap
                )
            colorstrip, anchor_coords = create_continuous_colorstrip(
                x.to_dict(),
                anchor_coords,
                width,
                tight_height,
                spacing,
                loc,
                continuous_cmap,
                vmin_annot,
                vmax_annot
            )

        # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.
        elif pd.api.types.is_string_dtype(x) or isinstance(x.dtype, pd.CategoricalDtype):
            x = x.astype('category')
            if cov in features:
                if categorical_cmaps is None or cov not in categorical_cmaps:
                    # Cycle through the defaults instead of running off the end of the list.
                    default_palette = _categorical_cmaps[n_cat % len(_categorical_cmaps)]
                    categorical_cmap = plu.create_palette(
                        tree.cell_meta, cov, default_palette, add_na=True
                    )
                else:
                    _cmap = categorical_cmaps[cov]
                    if isinstance(_cmap, (str, list)):
                        categorical_cmap = plu.create_palette(
                            tree.cell_meta, cov, _cmap, add_na=True
                        )
                    elif isinstance(_cmap, dict):
                        # Copy: the caller's dictionary must not be modified.
                        categorical_cmap = dict(_cmap)
                        categorical_cmap[np.nan] = 'lightgrey'
                    else:
                        raise ValueError(
                            f'Adjust categorical_cmaps. {cov}:\n'
                            'categorical_cmaps is nor a str, a list or a dict...'
                        )
            elif cov in characters:
                # Copy: missing categories are added below, and neither the
                # module-level default nor the caller's dictionary may change.
                categorical_cmap = dict(
                    bin_character_cmap if bin_character_cmap is not None
                    else _bin_character_cmap
                )

            # Give a color to every observed category the cmap does not cover.
            cats = x.unique()
            is_missing = [cat not in categorical_cmap for cat in cats]
            if any(is_missing):
                missing_cats = cats[is_missing]
                logging.info(f'Missing cats in cmap for meta feat {cov}: {missing_cats}. Adding new colors...')
                for i, missing in enumerate(missing_cats):
                    categorical_cmap[missing] = sc.pl.palettes.godsnot_102[i]
            # Missing values default to light grey unless a color is already defined
            # (replaces asserts, which raised KeyError on a cmap without a NaN key).
            categorical_cmap.setdefault(np.nan, 'lightgrey')

            # Place
            boxes, anchor_coords = ut.place_colorstrip(
                anchor_coords, width, tight_height, spacing, loc
            )
            colorstrip = {
                leaf: boxes[leaf] + (categorical_cmap[value], f"{leaf}\n{value}")
                for leaf, value in zip(x.index, x.values)
            }
            n_cat += 1

        else:
            raise ValueError(f'{cov} has {x.dtype} dtype. Check meta and layers...')

        colorstrips.append(colorstrip)

    # To polar, if necessary
    if is_polar:
        branch_coords = _to_polar_coords(branch_coords)
        node_coords = _to_polar_coords(node_coords)
        colorstrips = _to_polar_colorstrips(colorstrips)

    # Add feature names as colorstrips labels
    colorstrips = list(zip(colorstrips, covariates))

    return node_coords, branch_coords, colorstrips
##
def _set_colors(d, meta=None, cov=None, cmap=None, kwargs=None, vmin=None, vmax=None):
    """
    Create a dictionary of elements colors.

    Parameters
    ----------
    d : dict
        Elements to color. Only its keys are used, and only when no covariate
        is given.
    meta : pd.DataFrame, optional
        Annotation table indexed by element name.
    cov : str, optional
        Column of meta used to color the elements.
    cmap : str or dict, optional
        Colormap/palette name, or a value -> color dictionary (matched
        against the string representation of the column values).
    kwargs : dict, optional
        Plotting kwargs; its 'c' entry is the color used when meta or cov is
        missing ('k' if kwargs has no 'c').
    vmin, vmax : float, optional
        Color limits for numeric covariates. Each bound left to None is set
        to the 10th (vmin) or 90th (vmax) percentile of the values.

    Returns
    -------
    dict
        element -> color.

    Raises
    ------
    KeyError
        If cov is not a column of meta, or cmap is neither a str nor a dict.
    ValueError
        If cmap is a str and the column is neither numeric nor string.
    """
    # No covariate: same default color for every element.
    if meta is None or cov is None:
        default_color = (kwargs or {}).get('c', 'k')
        return {k: default_color for k in d}

    if cov not in meta.columns:
        raise KeyError(f'{cov} not present in cell_meta.')

    x = meta[cov]

    if isinstance(cmap, str):
        if pd.api.types.is_numeric_dtype(x):
            # The registry lookup already returns a Colormap;
            # matplotlib.cm.get_cmap is gone from recent matplotlib releases.
            colormap = matplotlib.colormaps[cmap]
            # Fill each bound independently, so that a single user-provided
            # limit is not overwritten.
            if vmin is None:
                vmin = np.percentile(x.values, 10)
            if vmax is None:
                vmax = np.percentile(x.values, 90)
            normalize = plt.Normalize(vmin=vmin, vmax=vmax)
            return {k: colormap(normalize(value)) for k, value in zip(x.index, x)}
        if pd.api.types.is_string_dtype(x):
            return x.map(plu.create_palette(meta, cov, cmap)).to_dict()
        # Previously fell through and failed with an UnboundLocalError.
        raise ValueError(
            f'{cov} has {x.dtype} dtype: a string cmap requires a numeric or string column.'
        )

    if isinstance(cmap, dict):
        print('User-provided colors dictionary...')
        return x.astype('str').map(cmap).to_dict()

    raise KeyError(f'{cov} You can either specify a string cmap or an element:color dictionary.')
##
# Main plotting function.
def plot_tree(
    tree: CassiopeiaTree,
    ax: matplotlib.axes.Axes = None,
    orient: float|str = 90,
    extend_branches: bool = True,
    angled_branches: bool = True,
    add_root: bool = False,
    features: Iterable[str] = None,
    categorical_cmaps: Dict[str, str|Dict[str,Any]] = None,
    continuous_cmaps: Dict[str, str|Dict[str,Any]] = None,
    characters: Iterable[str] = None,
    cont_character_cmap: str = 'mako',
    bin_character_cmap: Dict[str,Any] = None,
    layer: str ='raw',
    vmin_characters: float = 0,
    vmax_characters: float =.05,
    colorstrip_spacing: float =.25,
    colorstrip_width: float = 1.5,
    labels: bool = True,
    label_size: float = 10,
    label_offset: float = 2,
    meta_branches: pd.DataFrame = None,
    cov_branches: str = None,
    cmap_branches: str|Dict[str,Any] = 'Spectral_r',
    cov_leaves: str = None,
    cmap_leaves: str|Dict[str,Any] = 'tab20',
    feature_internal_nodes: str = None,
    cmap_internal_nodes: str|Dict[str,Any] ='Spectral_r',
    vmin: float = None,
    vmax: float = None,
    vmin_internal_nodes: float = .2,
    vmax_internal_nodes: float = .8,
    vmin_leaves: float = None,
    vmax_leaves: float = None,
    internal_node_labels: bool = False,
    internal_node_subset: Iterable[str] = None,
    internal_node_label_size: float = 7,
    show_internal: bool = False,
    leaves_labels: bool = False,
    leaf_label_size: float = 5,
    colorstrip_kwargs: Dict[str,Any] = None,
    leaf_kwargs: Dict[str,Any] = None,
    internal_node_kwargs: Dict[str,Any] = None,
    branch_kwargs: Dict[str,Any] = None,
    x_space: float = 1.5
) -> matplotlib.axes.Axes:
    """
    Plotting function that extends capabilities in cs.plotting.local.plot_matplotlib from
    Cassiopeia, MW Jones et al, 2020.

    Parameters
    ----------
    tree : CassiopeiaTree
        Tree to plot.
    ax : matplotlib.axes.Axes, optional
        Axes object to draw on. If None, a new figure and axes are created.
    orient : float or str, optional
        Tree layout in polar (90) or cartesian coordinates (e.g., "down"). Default is 90.
    extend_branches : bool, optional
        Equal length branch from leaf to root. Default is True.
    angled_branches : bool, optional
        Make branches angled, not round. Default is True.
    add_root : bool, optional
        Add root to tree. Default is False.
    features : Iterable[str], optional
        Features in tree.cell_meta to plot. Default is None.
    categorical_cmaps : dict of {str: str or dict}, optional
        Dictionary of colors for categorical features. Default is None.
    continuous_cmaps : dict of {str: str or dict}, optional
        Dictionary of colors for continuous features. Default is None.
    characters : Iterable[str], optional
        List of characters to plot. Default is None.
    cont_character_cmap : str, optional
        Color map for characters ("raw" layer). Default is "mako".
    bin_character_cmap : dict, optional
        Colors for binary character states ("transformed" layer). Default is None.
    layer : str, optional
        Layer in tree.layers to plot, if characters is not None. Default is "raw".
    vmin_characters : float, optional
        Minimum value for character colorbar. Default is 0.
    vmax_characters : float, optional
        Maximum value for character colorbar. Default is 0.05.
    colorstrip_spacing : float, optional
        Relative amount of spacing between colorstrips. Default is 0.25.
    colorstrip_width : float, optional
        Relative colorstrip width. Default is 1.5.
    labels : bool, optional
        Draw labels for features and characters (drawn only when orient is
        "down"). Default is True.
    label_size : float, optional
        Features and character label size. Default is 10.
    label_offset : float, optional
        Features and character label offset. Default is 2.
    meta_branches : pd.DataFrame, optional
        Annotation table for branches. Default is None.
    cov_branches : str, optional
        Branch feature to plot. Default is None.
    cmap_branches : str or dict, optional
        Color map for branch feature. Default is "Spectral_r".
    cov_leaves : str, optional
        Leaf feature to plot. Default is None.
    cmap_leaves : str or dict, optional
        Color map for leaves feature. Default is "tab20".
    vmin : float, optional
        Min value for continuous feature colorstrips. Default is None.
    vmax : float, optional
        Max value for continuous feature colorstrips. Default is None.
    vmin_leaves : float, optional
        Min value for leaves cmap.
    vmax_leaves : float, optional
        Max value for leaves cmap.
    feature_internal_nodes : str, optional
        Internal node feature to plot. Default is None.
    cmap_internal_nodes : str or dict, optional
        Color map for internal nodes feature. Default is "Spectral_r".
    vmin_internal_nodes : float, optional
        Minimum value for internal node feature colorbar. Default is 0.2.
    vmax_internal_nodes : float, optional
        Maximum value for internal node feature colorbar. Default is 0.8.
    internal_node_labels : bool, optional
        Draw internal node feature values on location. Default is False.
    internal_node_subset : Iterable[str], optional
        Subset of internal nodes to plot. Default is None.
    internal_node_label_size : float, optional
        Internal node name/label size. Default is 7.
    show_internal : bool, optional
        Show internal nodes. Default is False.
    leaves_labels : bool, optional
        Plot leaves names (only with orient="right"). Default is False.
    leaf_label_size : float, optional
        Leaf name/label size. Default is 5.
    colorstrip_kwargs : dict, optional
        Additional colorstrip keyword arguments. Default is None.
    leaf_kwargs : dict, optional
        Additional leaves keyword arguments. Default is None.
    internal_node_kwargs : dict, optional
        Additional internal nodes keyword arguments. Default is None.
    branch_kwargs : dict, optional
        Additional branch keyword arguments. Default is None.
    x_space : float, optional
        Horizontal offset added to the leaf x coordinate when placing leaf
        labels. Default is 1.5.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Axes object.

    Raises
    ------
    ValueError
        If leaves_labels is True and orient is not "right".
    """

    # Validate before drawing anything, so the axes are not left half-drawn.
    if leaves_labels and orient != 'right':
        raise ValueError(
            'Correct placement of labels at leaves implemented only for the right orient.'
        )

    # Set coord and axis
    if ax is None:
        # ax defaults to None: create a fresh axes rather than failing below.
        import matplotlib.pyplot as plt
        _, ax = plt.subplots()
    ax.axis('off')

    # Set graphic elements
    (
        node_coords,
        branch_coords,
        colorstrips,
    ) = _place_tree_and_annotations(
        tree,
        features=features,
        characters=characters,
        orient=orient,
        extend_branches=extend_branches,
        angled_branches=angled_branches,
        add_root=add_root,
        continuous_cmaps=continuous_cmaps,
        cont_character_cmap=cont_character_cmap,
        categorical_cmaps=categorical_cmaps,
        bin_character_cmap=bin_character_cmap,
        layer=layer,
        colorstrip_width=colorstrip_width,
        colorstrip_spacing=colorstrip_spacing,
        vmin=vmin,
        vmax=vmax,
        vmin_characters=vmin_characters,
        vmax_characters=vmax_characters
    )

    ##

    # Branches
    _branch_kwargs = {'linewidth':1, 'c':'k'}
    _branch_kwargs.update(branch_kwargs or {})
    branch_colors = _set_colors(
        branch_coords, meta=meta_branches, cov=cov_branches,
        cmap=cmap_branches, kwargs=_branch_kwargs
    )
    for branch, (xs, ys) in branch_coords.items():
        _dict = _branch_kwargs.copy()
        # Branches without an entry in the colors dict keep the default color.
        _dict['c'] = branch_colors[branch] if branch in branch_colors else _branch_kwargs['c']
        ax.plot(xs, ys, **_dict)

    ##

    # Colorstrips
    _colorstrip_kwargs = {'linewidth':0, 'alpha':1}
    _colorstrip_kwargs.update(colorstrip_kwargs or {})
    for colorstrip, feat in colorstrips:
        y_positions = []
        x_positions = []
        for xs, ys, c, _ in colorstrip.values():
            _dict = _colorstrip_kwargs.copy()
            _dict["facecolor"] = c
            polygon = Polygon(xy=list(zip(xs, ys)), closed=True, **_dict)
            polygon.set_rasterized(True)
            ax.add_patch(polygon)
            y_positions.extend(ys)
            x_positions.extend(xs)
        # Covariate name, to the left of the strip and vertically centered on it.
        if orient == 'down' and labels:
            y_mid = (min(y_positions) + max(y_positions)) / 2
            ax.text(
                min(x_positions) - label_offset, y_mid, feat,
                ha='right', va='center', fontsize=label_size
            )

    ##

    # Leaves
    leave_size = 2 if cov_leaves is not None else 0
    _leaf_kwargs = {'markersize':leave_size, 'c':'k', 'marker':'o'}
    _leaf_kwargs.update(leaf_kwargs or {})
    leaves = { node : node_coords[node] for node in node_coords if tree.is_leaf(node) }
    leaf_colors = _set_colors(
        leaves, meta=tree.cell_meta, cov=cov_leaves,
        cmap=cmap_leaves, kwargs=_leaf_kwargs,
        vmin=vmin_leaves, vmax=vmax_leaves
    )
    for node in leaves:
        x = leaves[node][0]
        y = leaves[node][1]
        _dict = _leaf_kwargs.copy()
        _dict['c'] = leaf_colors[node] if node in leaf_colors else _leaf_kwargs['c']
        ax.plot(x, y, **_dict)
        if leaves_labels:
            # orient == 'right' is guaranteed by the check at the top.
            ax.text(
                x+x_space, y, str(node), ha='center', va='center',
                fontsize=leaf_label_size
            )

    ##

    # Internal nodes
    _internal_node_kwargs = {
        'markersize': 0 if internal_node_labels else 2,
        'c':'white', 'marker':'o', 'alpha':1,
        'markeredgecolor':'k', 'markeredgewidth':1, 'zorder':10
    }
    _internal_node_kwargs.update(internal_node_kwargs or {})
    internal_nodes = {
        node : node_coords[node] for node in node_coords
        if tree.is_internal_node(node) and node != 'root'
    }

    # Subset nodes if necessary
    if internal_node_subset is not None:
        # Sets give O(1) membership tests in the filters below.
        tree_internal = set(tree.internal_nodes)
        subset = { node for node in internal_node_subset if node in tree_internal }
        internal_nodes = {
            node : coords for node, coords in internal_nodes.items() if node in subset
        }

    # Explicitly empty unless a feature is requested: the leaves colors must
    # not leak into this section.
    node_colors = {}
    if feature_internal_nodes is not None and internal_nodes:
        values = pd.Series({
            node : tree.get_attribute(node, feature_internal_nodes)
            for node in internal_nodes
        })
        values.loc[lambda x: x.isna()] = 0 # Set missing values to 0
        node_colors = _set_colors(
            internal_nodes, meta=values.to_frame(feature_internal_nodes),
            cov=feature_internal_nodes,
            cmap=cmap_internal_nodes, kwargs=_internal_node_kwargs,
            vmin=vmin_internal_nodes, vmax=vmax_internal_nodes
        )

    for node in internal_nodes:
        x = internal_nodes[node][0]
        y = internal_nodes[node][1]
        has_color = node in node_colors
        _dict = _internal_node_kwargs.copy()
        _dict.update({
            'c': node_colors[node] if has_color else _internal_node_kwargs['c'],
            # Nodes without a feature color are hidden unless show_internal is set.
            'markersize': (
                _internal_node_kwargs['markersize'] if (has_color or show_internal) else 0
            ),
        })
        ax.plot(x, y, **_dict)

        if internal_node_labels and has_color:
            v = tree.get_attribute(node, feature_internal_nodes)
            if isinstance(v, float):
                v = round(v, 2)
            ax.text(
                x+.3, y-.1, str(v), ha='center', va='bottom',
                bbox=dict(boxstyle='round', alpha=0, pad=10),
                fontsize=internal_node_label_size,
            )

    return ax
##