Source code for dynamo.plot.state_graph

import numpy as np
import pandas as pd
from anndata import AnnData
from typing import Optional, Union
from matplotlib.axes import Axes

from ..tools.utils import update_dict
from .scatters import scatters, docstrings
from .utils import save_fig

docstrings.delete_params("scatters.parameters", "aggregate", "kwargs", "save_kwargs")


def create_edge_patch(posA, posB, width=1, node_rad=0, connectionstyle="arc3, rad=0.25", facecolor="k", **kwargs):
    import matplotlib.patches as pat

    style = "simple,head_length=%d,head_width=%d,tail_width=%d" % (
        10,
        10,
        3 * width,
    )
    return pat.FancyArrowPatch(
        posA=posA,
        posB=posB,
        arrowstyle=style,
        connectionstyle=connectionstyle,
        facecolor=facecolor,
        shrinkA=node_rad,
        shrinkB=node_rad,
        **kwargs,
    )


def create_edge_patches_from_markov_chain(
    P,
    X,
    width=3,
    node_rad=0,
    tol=1e-7,
    connectionstyle="arc3, rad=0.25",
    facecolor="k",
    edgecolor="k",
    alpha=0.8,
    **kwargs
):
    """
    create edge patches from a markov chain transition matrix. If P[i, j] > tol, an arrow is created from
    node i to j.
    """
    arrows = []
    for i in range(P.shape[0]):
        for j in range(P.shape[0]):
            if P[i, j] > tol:
                if type(facecolor) == str:
                    fc = facecolor
                else:
                    if type(facecolor) == pd.DataFrame:
                        fc = facecolor.iloc[i, j]
                    else:
                        fc = facecolor[i, j]

                if type(edgecolor) == str:
                    ec = edgecolor
                else:
                    if type(edgecolor) == pd.DataFrame:
                        ec = edgecolor.iloc[i, j]
                    else:
                        ec = edgecolor[i, j]

                if type(alpha) == float:
                    ac = alpha * min(2 * P[i, j], 1)
                else:
                    if type(alpha) == pd.DataFrame:
                        ac = alpha.iloc[i, j]
                    else:
                        ac = alpha[i, j]

                arrows.append(
                    create_edge_patch(
                        X[i],
                        X[j],
                        width=P[i, j] * width,
                        node_rad=node_rad,
                        connectionstyle=connectionstyle,
                        facecolor=fc,
                        edgecolor=ec,
                        alpha=ac,
                        **kwargs,
                    )
                )
    return arrows


[docs]@docstrings.with_indent(4) def state_graph( adata: AnnData, group: Optional[str] = None, transition_threshold: float = 0.001, keep_only_one_direction: bool = True, edge_scale: float = 1, state_graph: Union[None, np.ndarray] = None, edgecolor: Union[None, np.ndarray, pd.DataFrame] = None, facecolor: Union[None, np.ndarray, pd.DataFrame] = None, graph_alpha: Union[None, np.ndarray, pd.DataFrame] = None, basis: str = "umap", x: int = 0, y: int = 1, color: str = "ntr", layer: str = "X", highlights: Optional[list] = None, labels: Optional[list] = None, values: Optional[list] = None, theme: Optional[str] = None, cmap: Optional[str] = None, color_key: Union[dict, list] = None, color_key_cmap: Optional[str] = None, background: Optional[str] = None, ncols: int = 4, pointsize: Union[None, float] = None, figsize: tuple = (6, 4), show_legend: bool = True, use_smoothed: bool = True, show_arrowed_spines: bool = False, ax: Optional[Axes] = None, sort: str = "raw", frontier: bool = False, save_show_or_return: str = "show", save_kwargs: dict = {}, s_kwargs_dict: dict = {"alpha": 1}, **kwargs ): """Plot a summarized cell type (state) transition graph. This function tries to create a model that summarizes the possible cell type transitions based on the reconstructed vector field function. Parameters ---------- group: `str` or `None` (default: `None`) The column in adata.obs that will be used to aggregate data points for the purpose of creating a cell type transition model. transition_threshold: `float` (default: 0.001) The threshold of cell fate transition. Transition will be ignored if below this threshold. keep_only_one_direction: `bool` (default: True) Whether to only keep the higher transition between two cell type. That is if the transition rate from A to B is higher than B to A, only edge from A to B will be plotted. edge_scale: `float` (default: 1) The scaler that can be used to scale the edge width of drawn transition graph. state_graph: `np.ndarray`, `pd.DataFrame` or `None` (default: None) The lumped transition graph between cell states (e.g. cell clusters or types). edgecolor: `np.ndarray`, `pd.DataFrame` or `None` (default: None) The edge color of the arcs that corresponds to the lumped transition graph between cell states. facecolor: `np.ndarray`, `pd.DataFrame` or `None` (default: None) The edge color of the arcs that corresponds to the lumped transition graph between cell states. graph_alpha: `np.ndarray`, `pd.DataFrame` or `None` (default: None) The alpha of the arcs that corresponds to the lumped transition graph between cell states. %(scatters.parameters.no_aggregate|kwargs|save_kwargs)s save_kwargs: `dict` (default: `{}`) A dictionary that will passed to the save_fig function. By default it is an empty dictionary and the save_fig function will use the {"path": None, "prefix": 'state_graph', "dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a dictionary that properly modify those keys according to your needs. s_kwargs_dict: `dict` (default: {"alpha": 1}) The dictionary of the scatter arguments. Returns ------- Plot the a model of cell fate transition that summarizes the possible lineage commitments between different cell types. """ import matplotlib.pyplot as plt from matplotlib import rcParams from matplotlib.colors import to_hex aggregate = group points = adata.obsm["X_" + basis][:, [x, y]] unique_group_obs = adata.obs[group].unique() if type(unique_group_obs) is np.ndarray: groups, uniq_grp = adata.obs[group], unique_group_obs.tolist() elif type(unique_group_obs) is pd.Series: groups, uniq_grp = adata.obs[group], unique_group_obs.to_list() else: groups, uniq_grp = adata.obs[group], list(unique_group_obs) group_median = np.zeros((len(uniq_grp), 2)) # grp_size = adata.obs[group].value_counts()[uniq_grp].values # s_kwargs_dict.update({"s": grp_size}) if state_graph is None: Pl = adata.uns[group + "_graph"]["group_graph"] if keep_only_one_direction: Pl[Pl - Pl.T < 0] = 0 if transition_threshold is not None: Pl[Pl < transition_threshold] = 0 Pl /= Pl.sum(1)[:, None] * edge_scale else: Pl = state_graph for i, cur_grp in enumerate(uniq_grp): group_median[i, :] = np.nanmedian(points[np.where(groups == cur_grp)[0], :2], 0) if background is None: _background = rcParams.get("figure.facecolor") background = to_hex(_background) if type(_background) is tuple else _background plt.figure(facecolor=_background) axes_list, color_list, font_color = scatters( adata=adata, basis=basis, x=x, y=y, color=color, layer=layer, highlights=highlights, labels=labels, values=values, theme=theme, cmap=cmap, color_key=color_key, color_key_cmap=color_key_cmap, background=background, ncols=ncols, pointsize=pointsize, figsize=figsize, show_legend=show_legend, use_smoothed=use_smoothed, aggregate=aggregate, show_arrowed_spines=show_arrowed_spines, ax=ax, sort=sort, save_show_or_return="return", frontier=frontier, **s_kwargs_dict, return_all=True, ) edgecolor = "k" if edgecolor is None else edgecolor facecolor = "k" if facecolor is None else facecolor graph_alpha = 0.8 if graph_alpha is None else graph_alpha arrows = create_edge_patches_from_markov_chain( Pl, group_median, edgecolor=edgecolor, facecolor=facecolor, alpha=graph_alpha, tol=0.01, node_rad=15 ) if type(axes_list) == list: for i in range(len(axes_list)): for arrow in arrows: axes_list[i].add_patch(arrow) axes_list[i].set_facecolor(background) else: for arrow in arrows: axes_list.add_patch(arrow) axes_list.set_facecolor(background) plt.axis("off") if save_show_or_return == "save": s_kwargs = { "path": None, "prefix": "state_graph", "dpi": None, "ext": "pdf", "transparent": True, "close": True, "verbose": True, } s_kwargs = update_dict(s_kwargs, save_kwargs) save_fig(**s_kwargs) elif save_show_or_return == "show": if show_legend: plt.subplots_adjust(right=0.85) plt.tight_layout() plt.show() elif save_show_or_return == "return": return axes_list, color_list, font_color