Source code for pyrovelocity.plot

from typing import Dict
from typing import List
from typing import Tuple

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import scvelo as scv
import seaborn as sns
import sklearn
import torch
import umap
from adjustText import adjust_text
from anndata import AnnData
from astropy.stats import rayleightest
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy import ndarray
from scipy.stats import spearmanr
from scvelo.plotting.velocity_embedding_grid import default_arrow
from sklearn.pipeline import Pipeline

from pyrovelocity.cytotrace import compute_similarity2
from pyrovelocity.utils import ensure_numpy_array
from pyrovelocity.utils import mRNA
from pyrovelocity.utils import mse_loss_sum


def plot_evaluate_dynamic_orig(adata, gene="Cpe", velocity=None, ax=None):
    """Evaluate scVelo's fitted dynamical model for one gene, plot and score it.

    The per-gene parameters are read from the ``fit_*`` columns of
    ``adata.var`` and the per-cell latent time from the ``fit_t`` layer. The
    two-phase (induction / repression) solution is evaluated with ``mRNA`` and
    the summed MSE against the ``Mu`` / ``Ms`` layers is printed.

    Args:
        adata: AnnData holding scVelo dynamical-model results.
        gene: Name of the gene (a label of ``adata.var``) to evaluate.
        velocity: Optional object whose ``weight`` attribute indexes the cells
            used for the loss. When ``None`` all cells are used.
        ax: Matplotlib axes to draw the fitted trajectory on. When ``None`` a
            new figure is created and scVelo's own phase portrait is drawn.

    Returns:
        Tuple ``(alpha_vec, ut, st, xnew, ynew)``: the per-cell transcription
        rate, the modelled unspliced and spliced values, and the x/y
        coordinates of the ``gamma / beta`` line.
    """
    alpha, beta, gamma, scaling, t_ = (
        torch.tensor(adata.var.loc[gene, f"fit_{name}"])
        for name in ("alpha", "beta", "gamma", "scaling", "t_")
    )
    beta_scale = beta * scaling
    u0, s0 = adata.var.loc[gene, "fit_u0"], adata.var.loc[gene, "fit_s0"]
    t = torch.tensor(adata[:, gene].layers["fit_t"][:, 0])

    # State of the system at the switching time t_.
    u_inf, s_inf = mRNA(t_, u0, s0, alpha, beta_scale, gamma)

    # state == 1 before the switch (alpha on), 0 afterwards (alpha off).
    state = (t < t_).int()
    tau = t * state + (t - t_) * (1 - state)
    u0_vec = u0 * state + u_inf * (1 - state)
    s0_vec = s0 * state + s_inf * (1 - state)
    alpha_ = 0
    alpha_vec = alpha * state + alpha_ * (1 - state)

    ut, st = mRNA(tau, u0_vec, s0_vec, alpha_vec, beta_scale, gamma)
    ut = ut * scaling + u0
    st = st + s0

    # ``steps`` is mandatory in current PyTorch; 100 was the historical
    # default used when the argument was omitted.
    xnew = torch.linspace(st.min().item(), st.max().item(), 100)
    ynew = gamma / beta * (xnew - torch.min(xnew)) + torch.min(ut)

    if ax is None:
        fig, ax = plt.subplots()
        scv.pl.scatter(adata, gene, color=["clusters"], ax=ax, show=False)
    else:
        # NOTE(review): nesting reconstructed from a flattened listing; both
        # drawing calls are assumed to belong to this branch.
        ax.scatter(
            st.detach().numpy(),
            ut.detach().numpy(),
            linestyle="-",
            linewidth=5,
            alpha=0.3,
        )
        ax.plot(
            xnew.detach().numpy(),
            ynew.detach().numpy(),
            color="b",
            linestyle="--",
            linewidth=5,
        )

    observed_u = adata[:, gene].layers["Mu"].toarray()
    observed_s = adata[:, gene].layers["Ms"].toarray()
    if velocity is None:
        loss = mse_loss_sum(ut, st, observed_u[:, 0], observed_s[:, 0])
    else:
        subset = velocity.weight
        loss = mse_loss_sum(
            ut[subset], st[subset], observed_u[subset, 0], observed_s[subset, 0]
        )
    print("scvelo %s mse loss:" % gene, loss)
    return alpha_vec, ut, st, xnew, ynew
def plot_dynamic_pyro(
    adata,
    gene,
    losses,
    summary,
    velocity,
    fix_param_list,
    alpha,
    beta,
    gamma,
    scale,
    t_,
    t,
):
    """Draw a three-panel diagnostic figure for a single-gene Pyro fit.

    Panel 0 shows the loss trace, panel 1 compares scVelo's ``fit_t`` with the
    Pyro latent time, and panel 2 overlays the Pyro trajectory, the posterior
    mean observations and the scVelo fit on the phase portrait.

    Args:
        adata: AnnData with the ``fit_t`` layer and a ``clusters`` annotation.
        gene: Gene to display.
        losses: Sequence of loss values, one per optimisation step.
        summary: Mapping with ``summary["x_obs"]["mean"]``; column 0 is used
            as unspliced and column 1 as spliced.
        velocity: Object exposing ``weight`` (cell index) and ``x``.
        fix_param_list: Six flags. For entries 0-4 a value of 1 means the
            matching argument is used as given, otherwise the value is read
            from the Pyro parameter store. Entry 5 uses the store when it is 0.
        alpha, beta, gamma, scale, t_, t: Fixed values for the parameters.

    Returns:
        Tuple ``(alpha_vec, ut, st, xnew, ynew)``.
    """

    def resolve(position, fixed_value, site):
        # Flag == 1 -> caller-supplied constant, anything else -> learned value.
        if fix_param_list[position] == 1:
            return torch.tensor(fixed_value)
        return pyro.param(site)

    alpha = resolve(0, alpha, "AutoDelta.alpha_sample")
    beta = resolve(1, beta, "AutoDelta.beta_sample")
    gamma = resolve(2, gamma, "AutoDelta.gamma_sample")
    scale = resolve(3, scale, "AutoDelta.scale_sample")
    t_ = resolve(4, t_, "AutoDelta.switching_sample")
    # The latent time is keyed on ``== 0`` rather than ``== 1``.
    t = (
        pyro.param("AutoDelta.latent_time")
        if fix_param_list[5] == 0
        else torch.tensor(t)
    )

    fig, panels = plt.subplots(1, 3)
    fig.set_size_inches(16, 4)
    loss_ax, time_ax, phase_ax = panels[0], panels[1], panels[2]

    scvelo_time = adata[:, gene].layers["fit_t"].toarray()[velocity.weight, 0]
    time_ax.scatter(scvelo_time, t.data.cpu().numpy())

    # Replace the per-cell times by a dense grid covering the whole range.
    t_end = (t.sort()[0].max() + 1).int()
    t = torch.linspace(0.0, t_end, 500)

    beta_scale = beta * scale
    u0, s0 = torch.tensor(0.0), torch.tensor(0.0)
    u_inf, s_inf = mRNA(t_, u0, s0, alpha, beta_scale, gamma)

    state = (t < t_).int()
    after_switch = 1 - state
    tau = t * state + (t - t_) * after_switch
    u0_vec = u0 * state + u_inf * after_switch
    s0_vec = s0 * state + s_inf * after_switch
    alpha_ = 0.0
    alpha_vec = alpha * state + alpha_ * after_switch

    ut, st = mRNA(tau, u0_vec, s0_vec, alpha_vec, beta_scale, gamma)
    ut = ut * scale + u0
    st = st + s0

    s_low = torch.tensor(st.min().detach().numpy())
    s_high = torch.tensor(st.max().detach().numpy())
    xnew = torch.linspace(s_low, s_high, 50)
    ynew = gamma / beta * (xnew - torch.min(xnew)) + torch.min(ut)

    loss_ax.plot(losses)
    loss_ax.set_yscale("log")
    loss_ax.set_title("ELBO")
    loss_ax.set_xlabel("step")
    loss_ax.set_ylabel("loss")

    scv.pl.scatter(adata, gene, color=["clusters"], ax=phase_ax, show=False)
    posterior_mean = summary["x_obs"]["mean"]
    phase_ax.scatter(
        posterior_mean[:, 1],
        posterior_mean[:, 0],
        alpha=0.5,
        s=5,
        color="r",
    )
    phase_ax.plot(
        st.detach().numpy(),
        ut.detach().numpy(),
        linestyle="-",
        linewidth=5,
        color="g",
    )
    phase_ax.plot(
        xnew.detach().numpy(),
        ynew.detach().numpy(),
        color="g",
        linestyle="--",
        linewidth=5,
    )
    plot_evaluate_dynamic_orig(adata, gene, velocity, phase_ax)

    pyro_loss = mse_loss_sum(
        summary["x_obs"]["mean"][:, 0],
        summary["x_obs"]["mean"][:, 1],
        velocity.x[:, 0],
        velocity.x[:, 1],
    )
    print("pyro model mse loss", pyro_loss)
    return alpha_vec, ut, st, xnew, ynew
def plot_multigenes_dynamical(
    summary,
    alpha,
    beta,
    gamma,
    t_,
    t,
    adata,
    gene="Cpe",
    scale=None,
    ax=None,
    raw=False,
):
    """Overlay the modelled phase-portrait trajectory of one gene on the data.

    Args:
        summary: Optional mapping with ``summary["x_obs"]["mean"]``; when not
            ``None`` column 1 is scattered against column 0.
        alpha, beta, gamma: Kinetic rates passed to ``mRNA``.
        t_: Switching time.
        t: Tensor of cell times; only its maximum is used, to build a
            500-point time grid.
        adata: AnnData handed to ``scv.pl.scatter``.
        gene: Gene to display.
        scale: Optional scaling applied to the unspliced values.
        ax: Axes to draw on; a new figure is created when ``None``.
        raw: When true the ``spliced`` / ``unspliced`` layers are scattered
            instead of scVelo's defaults.

    Returns:
        None. Everything is drawn on ``ax``.
    """
    t_max = t.sort()[0].max().int()
    t = torch.linspace(0.0, t_max, 500)
    u0, s0 = torch.tensor(0.0), torch.tensor(0.0)
    u_inf, s_inf = mRNA(t_, u0, s0, alpha, beta, gamma)

    # state == 1 before the switch, 0 afterwards.
    state = (t < t_).int()
    tau = t * state + (t - t_) * (1 - state)
    u0_vec = u0 * state + u_inf * (1 - state)
    s0_vec = s0 * state + s_inf * (1 - state)
    alpha_ = 0.0
    alpha_vec = alpha * state + alpha_ * (1 - state)

    ut, st = mRNA(tau, u0_vec, s0_vec, alpha_vec, beta, gamma)
    ut = ut + u0 if scale is None else ut * scale + u0
    st = st + s0

    xnew = torch.linspace(
        torch.tensor(st.min().detach().numpy()),
        torch.tensor(st.max().detach().numpy()),
        50,
    )
    if scale is not None:
        # NOTE(review): ``ut`` was already scaled above, so the offset below
        # applies ``scale`` twice. Kept as in the original - confirm intent.
        ynew = (gamma / beta * (xnew - torch.min(xnew))) * scale + torch.min(
            ut * scale + u0
        )
    else:
        ynew = gamma / beta * (xnew - torch.min(xnew)) + torch.min(ut)

    if ax is None:
        _, ax = plt.subplots()

    scatter_kwargs = {"ax": ax, "show": False}
    if raw:
        scatter_kwargs.update(x="spliced", y="unspliced")
    try:
        scv.pl.scatter(adata, gene, color=["clusters"], **scatter_kwargs)
    except Exception:
        # Fall back to the capitalised annotation name. Narrowed from a bare
        # ``except`` so KeyboardInterrupt / SystemExit are no longer swallowed.
        scv.pl.scatter(adata, gene, color=["Clusters"], **scatter_kwargs)

    ax.plot(
        st.detach().numpy(),
        ut.detach().numpy(),
        linestyle="-",
        linewidth=2.5,
        color="red",
        label="Pyro-Velocity",
    )
    ax.plot(
        xnew.detach().numpy(),
        ynew.detach().numpy(),
        color="red",
        linestyle="--",
        linewidth=2.5,
        alpha=0.4,
    )
    if summary is not None:
        ax.scatter(
            summary["x_obs"]["mean"][:, 1],
            summary["x_obs"]["mean"][:, 0],
            alpha=0.5,
            color="red",
        )
def plot_posterior_time(
    posterior_samples,
    adata,
    ax=None,
    fig=None,
    basis="umap",
    addition=True,
    position="left",
    s=3,
):
    """Colour an embedding by the posterior mean of the shared cell time.

    The posterior mean of ``posterior_samples["cell_time"]`` (over axis 0) is
    divided by its maximum and stored in ``adata.obs["cell_time"]``, which is
    then used to colour ``adata.obsm["X_<basis>"]``.

    Args:
        posterior_samples: Mapping containing ``"cell_time"``.
        adata: AnnData with the embedding; modified in place.
        ax: Axes to draw on. A new 2.36 x 2 inch figure is created if ``None``.
        fig: Figure owning ``ax``. Resolved from ``ax`` when omitted.
        basis: Embedding key suffix in ``adata.obsm``.
        addition: When true, set the seaborn/matplotlib style and draw a
            histogram of the mean cell time in a separate figure.
        position: Side of the axes on which the colour bar is placed.
        s: Marker size.

    Returns:
        None.
    """
    pos_mean_time = posterior_samples["cell_time"].mean(0)

    if addition:
        sns.set_style("white")
        sns.set_context("paper", font_scale=1)
        matplotlib.rcParams.update({"font.size": 7})
        plt.figure()
        plt.hist(pos_mean_time, bins=100, label="test")
        plt.xlabel("mean of cell time")
        plt.ylabel("frequency")
        plt.title("Histogram of cell time posterior samples")
        plt.legend()

    adata.obs["cell_time"] = pos_mean_time / pos_mean_time.max()

    if ax is None:
        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(2.36, 2)
    elif fig is None:
        # The colour bar needs a figure; previously this path crashed with
        # an AttributeError on ``None``.
        fig = ax.get_figure()

    embedding = adata.obsm[f"X_{basis}"]
    im = ax.scatter(
        embedding[:, 0],
        embedding[:, 1],
        s=s,
        alpha=0.4,
        c=adata.obs["cell_time"],
        cmap="inferno",
        linewidth=0,
    )
    set_colorbar(im, ax, labelsize=5, fig=fig, position=position)
    ax.axis("off")

    if "cytotrace" in adata.obs.columns:
        correlation = spearmanr(
            adata.obs["cell_time"].values, 1 - adata.obs.cytotrace.values
        )[0]
        ax.set_title(
            "Pyro-Velocity shared time\ncorrelation with Cytotrace: %.2f"
            % correlation,
            fontsize=7,
        )
    else:
        ax.set_title("Pyro-Velocity shared time\n", fontsize=7)
def mae_per_gene(pred_counts: ndarray, true_counts: ndarray) -> ndarray:
    """Return the negated, count-normalised absolute error per gene.

    The absolute differences are summed over the second-to-last axis and
    divided by the summed true counts along the same axis (clipped below at 1
    to avoid division by zero). The result is negated, so values closer to
    zero indicate a better fit.

    Args:
        pred_counts: Predicted counts; any array-like accepted by
            ``np.asarray``.
        true_counts: Observed counts, broadcastable against ``pred_counts``.

    Returns:
        Array of negated normalised errors with the second-to-last axis
        reduced.
    """
    # Coerce so that nested lists and other array-likes are accepted as well.
    pred = np.asarray(pred_counts)
    true = np.asarray(true_counts)
    error = np.abs(true - pred).sum(-2)
    total = np.clip(true.sum(-2), 1, np.inf)
    return -np.array(error / total)
def compute_volcano_data(
    posterior_samples,
    adata,
    time_correlation_with="s",
    selected_genes=None,
    negative=False,
) -> tuple:
    """Build the per-gene table used for gene ranking ("volcano") plots.

    For every posterior sample of every dataset, the negated normalised
    absolute error of the predicted spliced counts (``mae_per_gene``) and the
    per-gene correlation with cell time (``compute_similarity2``) are
    collected and then averaged per gene.

    Args:
        posterior_samples: List/tuple of posterior dictionaries; the first
            must contain ``"s"`` and ``"alpha"``. At most two entries are
            used (labelled ``train`` and ``valid``).
        adata: List/tuple of AnnData objects paired with
            ``posterior_samples``; each needs a ``raw_spliced`` layer.
        time_correlation_with: Posterior key correlated with ``cell_time``.
        selected_genes: Genes to flag. When ``None`` the four genes with the
            highest (or lowest, if ``negative``) time correlation among the
            300 best-fitting genes are selected.
        negative: Sort time correlation ascending when auto-selecting.

    Returns:
        Tuple ``(volcano_data, genes)``: a DataFrame indexed by gene with the
        columns ``mean_mae``, ``time_correlation``, their ranks,
        ``rank_product`` and ``selected genes``, and the selected genes (a
        pandas Index, or ``selected_genes`` unchanged when it was supplied).
    """
    assert isinstance(posterior_samples, (tuple, list))
    assert isinstance(adata, (tuple, list))
    assert "s" in posterior_samples[0]
    assert "alpha" in posterior_samples[0]

    maes_list, cors, genes, labels = [], [], [], []
    for p, ad, label in zip(posterior_samples, adata, ["train", "valid"]):
        print(label)
        # Loop-invariant: convert the observed counts once per dataset.
        observed_spliced = ensure_numpy_array(ad.layers["raw_spliced"])
        gene_names = ad.var_names.values
        for sample in range(p["alpha"].shape[0]):
            maes_list.append(
                mae_per_gene(p["s"][sample].squeeze(), observed_spliced)
            )
            df_genes_cors = compute_similarity2(
                p[time_correlation_with][sample].squeeze(),
                p["cell_time"][sample].squeeze().reshape(-1, 1),
            )
            cors.append(df_genes_cors[0])
            genes.append(gene_names)
            labels.append([f"Poisson_{label}"] * len(gene_names))

    volcano_data = pd.DataFrame(
        {
            "mean_mae": np.hstack(maes_list),
            "label": np.hstack(labels),
            "time_correlation": np.hstack(cors),
            "genes": np.hstack(genes),
        }
    )
    # Select the numeric columns explicitly. The previous call passed the
    # column list as ``numeric_only`` and only worked because a non-empty
    # list is truthy.
    volcano_data = volcano_data.groupby("genes")[
        ["mean_mae", "time_correlation"]
    ].mean()
    volcano_data.loc[:, "mean_mae_rank"] = volcano_data.mean_mae.rank(
        ascending=False
    )
    volcano_data.loc[:, "time_correlation_rank"] = (
        volcano_data.time_correlation.apply(abs).rank(ascending=False)
    )
    volcano_data.loc[:, "rank_product"] = (
        volcano_data.mean_mae_rank * volcano_data.time_correlation_rank
    )

    if selected_genes is None:
        genes = (
            volcano_data.sort_values("mean_mae", ascending=False)
            .head(300)
            .sort_values("time_correlation", ascending=negative)
            .head(4)
            .index
        )
    else:
        genes = selected_genes

    volcano_data.loc[:, "selected genes"] = 0
    volcano_data.loc[genes, "selected genes"] = 1
    return volcano_data, genes
import matplotlib.gridspec as gridspec
def plot_gene_ranking(
    posterior_samples,
    adata,
    ax=None,
    time_correlation_with="s",
    selected_genes=None,
    assemble=False,
    data="correlation",
    negative=False,
    adjust_text_bool=False,
    show_marginal_histograms=False,
) -> tuple:
    """Scatter genes by time correlation (or switching time) against fit error.

    Args:
        posterior_samples: List/tuple of posterior dictionaries. The ranking
            table is recomputed with ``compute_volcano_data`` when
            ``selected_genes`` is ``None`` and the first entry contains
            ``"u"``; otherwise the precomputed ``"gene_ranking"`` (and
            ``"genes"``) entries of the first dictionary are used.
        adata: List/tuple of AnnData objects, forwarded to
            ``compute_volcano_data``.
        ax: Axes to draw on. Defaults to the current axes. Ignored when
            ``show_marginal_histograms`` is true, which builds its own figure.
        time_correlation_with: Forwarded to ``compute_volcano_data``.
        selected_genes: Optional list/tuple of gene names to highlight.
        assemble: Unused; kept for backward compatibility.
        data: ``"correlation"`` plots ``time_correlation`` on the x axis; any
            other value plots the ``switching`` column instead.
        negative: Forwarded to ``compute_volcano_data``.
        adjust_text_bool: Selects the tuned ``adjust_text`` parameter set
            instead of the default one.
        show_marginal_histograms: Add marginal histograms in a new 10 x 10
            inch figure.

    Returns:
        Tuple ``(volcano_data, fig)``; ``fig`` is ``None`` unless the marginal
        histogram layout was created.
    """
    if selected_genes is not None:
        assert isinstance(selected_genes, (tuple, list))
        assert isinstance(selected_genes[0], str)
        volcano_data = posterior_samples[0]["gene_ranking"]
        genes = selected_genes
    elif "u" in posterior_samples[0]:
        volcano_data, genes = compute_volcano_data(
            posterior_samples, adata, time_correlation_with, selected_genes, negative
        )
    else:
        volcano_data = posterior_samples[0]["gene_ranking"]
        genes = posterior_samples[0]["genes"]

    fig = None

    if data != "correlation":
        if ax is None:
            # seaborn would draw on the current axes; label the same axes.
            ax = plt.gca()
        sns.scatterplot(
            x="switching",
            y="mean_mae",
            hue="selected genes",
            data=volcano_data,
            s=3,
            linewidth=0,
            ax=ax,
            legend=False,
            alpha=0.3,
        )
        ax.set_title("Pyro-Velocity genes", fontsize=7)
        ax.set_xlabel("gene switching time", fontsize=7)
        ax.set_ylabel("negative mean\nabsolute error", fontsize=7)
        sns.despine()
        return volcano_data, fig

    defaultfontsize = 7
    defaultdotsize = 3
    plot_title = "Pyro-Velocity genes"

    if show_marginal_histograms:
        time_corr_hist, time_corr_bins = np.histogram(
            volcano_data["time_correlation"], bins="auto", density=False
        )
        mean_mae_hist, mean_mae_bins = np.histogram(
            volcano_data["mean_mae"], bins="auto", density=False
        )

        fig = plt.figure(figsize=(10, 10))
        gs = gridspec.GridSpec(
            3, 3, width_ratios=[2, 2, 1], height_ratios=[1, 2, 2]
        )
        ax_scatter = plt.subplot(gs[1:, :2])
        ax_hist_x = plt.subplot(gs[0, :2])
        ax_hist_y = plt.subplot(gs[1:, 2])

        # Time-correlation histogram above the scatter.
        ax_hist_x.bar(
            time_corr_bins[:-1],
            time_corr_hist,
            width=np.diff(time_corr_bins),
            align="edge",
        )
        # Error histogram to the right of the scatter.
        ax_hist_y.barh(
            mean_mae_bins[:-1],
            mean_mae_hist,
            height=np.diff(mean_mae_bins),
            align="edge",
        )
        ax_hist_x.tick_params(axis="x", labelbottom=False)
        ax_hist_y.tick_params(axis="y", labelleft=False)

        defaultfontsize = 14
        defaultdotsize = 12
        plot_title = ""
        ax = ax_scatter

    if ax is None:
        # Previously ``ax=None`` crashed on ``ax.set_title`` below.
        ax = plt.gca()

    sns.scatterplot(
        x="time_correlation",
        y="mean_mae",
        hue="selected genes",
        data=volcano_data,
        s=defaultdotsize,
        linewidth=0,
        ax=ax,
        legend=False,
        alpha=0.3,
    )
    ax.set_title(plot_title, fontsize=defaultfontsize)
    ax.set_xlabel(
        "shared time correlation\nwith spliced expression", fontsize=defaultfontsize
    )
    ax.set_ylabel("negative mean\nabsolute error", fontsize=defaultfontsize)
    sns.despine()
    ax.tick_params(labelsize=defaultfontsize - 1)

    texts = []
    for g in genes:
        x_pos = volcano_data.loc[g, :].time_correlation
        y_pos = volcano_data.loc[g, :].mean_mae
        ax.scatter(x_pos, y_pos, s=15, color="red", marker="*")
        texts.append(
            ax.text(
                x_pos,
                y_pos,
                g,
                fontsize=defaultfontsize - 2,
                color="black",
                ha="center",
                va="center",
            )
        )

    # NOTE(review): reconstructed from a flattened listing; the label
    # adjustment is assumed to run once after all labels have been created.
    if not adjust_text_bool:
        adjust_text(
            texts,
            arrowprops=dict(arrowstyle="-", color="red", alpha=0.5),
            ha="center",
            va="bottom",
            ax=ax,
        )
    else:
        adjust_text(
            texts,
            precision=0.001,
            expand_text=(1.01, 1.05),
            expand_points=(1.01, 1.05),
            force_text=(0.01, 0.25),
            force_points=(0.01, 0.25),
            arrowprops=dict(arrowstyle="-", color="blue", alpha=0.6),
            ax=ax,
        )
    return volcano_data, fig
def denoised_umap(posterior_samples, adata, cell_state="state_info"):
    """Plot UMAPs of denoised expression with the posterior-mean vector field.

    Two embeddings are computed with a PCA(50) -> UMAP(2) pipeline: one from
    the posterior mean of spliced expression (stored as ``X_umap1``) and one
    from spliced and unspliced means stacked horizontally (``X_umap2``). For
    each, the left panel shows the embedding and the right panel the velocity
    grid coloured by ``cell_state``.

    Side effects on ``adata``: writes ``obsm["X_umap1"]``, ``obsm["X_umap2"]``,
    ``obsm["X_pyropca"]``, ``layers["spliced_pyro"]``,
    ``layers["velocity_pyro"]`` and whatever scVelo's neighbour, velocity
    graph and velocity embedding routines store.

    Args:
        posterior_samples: Mapping with ``"st"``, ``"ut"``, ``"beta"`` and
            ``"gamma"``; ``"u_scale"`` / ``"s_scale"`` are used when
            ``"u_scale"`` is present.
        adata: AnnData to annotate and plot.
        cell_state: ``adata.obs`` key used to colour the velocity grid.

    Returns:
        None.
    """
    import umap
    from sklearn.decomposition import PCA
    from sklearn.pipeline import Pipeline

    def posterior_mean_velocity():
        # beta * u (rescaled if scale factors were sampled) - gamma * s,
        # averaged over posterior samples.
        unspliced_term = posterior_samples["ut"] * posterior_samples["beta"]
        if "u_scale" in posterior_samples:
            unspliced_term = unspliced_term / (
                posterior_samples["u_scale"] / posterior_samples["s_scale"]
            )
        spliced_term = posterior_samples["st"] * posterior_samples["gamma"]
        return (unspliced_term - spliced_term).mean(0)

    pipeline = Pipeline(
        [
            ("PCA", PCA(random_state=99, n_components=50)),
            ("UMAP", umap.UMAP(random_state=99, n_components=2)),
        ]
    )
    fig, ax = plt.subplots(2, 2)
    fig.set_size_inches(9, 9)

    spliced_mean = posterior_samples["st"].mean(0)
    unspliced_mean = posterior_samples["ut"].mean(0)
    velocity_mean = posterior_mean_velocity()

    panels = (
        ("umap1", spliced_mean),
        ("umap2", np.hstack([spliced_mean, unspliced_mean])),
    )
    for row, (basis, expression) in enumerate(panels):
        pipeline.fit(expression)
        embedding = pipeline.transform(expression)
        adata.obsm[f"X_{basis}"] = embedding
        scv.pl.scatter(adata, basis=basis, ax=ax[row][0], show=False)

        # Neighbours are computed in the PCA space of the same pipeline.
        adata.obsm["X_pyropca"] = pipeline.steps[0][1].transform(expression)
        scv.pp.neighbors(adata, use_rep="pyropca")

        adata.layers["spliced_pyro"] = spliced_mean
        adata.layers["velocity_pyro"] = velocity_mean
        scv.tl.velocity_graph(adata, vkey="velocity_pyro", xkey="spliced_pyro")
        # The second panel used to request basis "umap1" here although
        # "umap2" is what gets plotted.
        scv.tl.velocity_embedding(adata, vkey="velocity_pyro", basis=basis)
        scv.pl.velocity_embedding_grid(
            adata,
            basis=basis,
            vkey="velocity_pyro",
            density=0.5,
            scale=0.25,
            arrow_size=3,
            color=cell_state,
            ax=ax[row][1],
            show=False,
        )
        # Re-store the embedding after plotting, as the original did for the
        # first panel.
        adata.obsm[f"X_{basis}"] = embedding
def vector_field_uncertainty(
    adata: AnnData,
    posterior_samples: Dict[str, ndarray],
    basis: str = "tsne",
    n_jobs: int = 1,
    denoised: bool = False,
) -> Tuple[ndarray, ndarray, ndarray]:
    """Project every posterior velocity sample onto an embedding.

    For each posterior sample the spliced and velocity layers of ``adata`` are
    overwritten and scVelo computes the embedded velocity. The direction
    angles of all samples are then tested per cell with a Rayleigh test,
    corrected with Benjamini-Hochberg.

    Args:
        adata: AnnData; neighbours, layers and ``obsm`` entries are
            overwritten in place.
        posterior_samples: Mapping with three-dimensional ``"st"`` plus
            ``"ut"`` and either ``beta``/``gamma`` or ``beta_k``/``gamma_k``;
            optional ``u_scale`` / ``s_scale``.
        basis: Embedding to project onto. ``"pca"`` uses a direct PCA
            projection instead of the velocity graph.
        n_jobs: Parallel jobs for ``scv.tl.velocity_graph``.
        denoised: Build neighbours from a PCA of the posterior-mean spliced
            expression instead of the stored ``pca`` representation.

    Returns:
        Tuple ``(v_map_all, embeds_radian, fdri)``: stacked embedded
        velocities, their angles in radians, and the FDR-corrected Rayleigh
        p-values.
    """
    spliced_samples = posterior_samples["st"]
    unspliced_samples = posterior_samples["ut"]

    # Pick the unspliced rescaling factor depending on which scales exist.
    has_u_scale = "u_scale" in posterior_samples
    if has_u_scale and "s_scale" in posterior_samples:
        # Gaussian models
        scale = posterior_samples["u_scale"] / posterior_samples["s_scale"]
    elif has_u_scale:
        # Poisson Model 2
        scale = posterior_samples["u_scale"]
    else:
        # Poisson Model 1
        scale = 1

    if "beta_k" in posterior_samples:
        production = unspliced_samples * posterior_samples["beta_k"] / scale
        degradation = spliced_samples * posterior_samples["gamma_k"]
    else:
        production = posterior_samples["beta"] * unspliced_samples / scale
        degradation = posterior_samples["gamma"] * spliced_samples
    velocity_samples = production - degradation

    if not denoised:
        scv.pp.neighbors(adata, use_rep="pca")
    else:
        pipelines = Pipeline(
            [
                ("PCA", sklearn.decomposition.PCA(random_state=99, n_components=50)),
                ("UMAP", umap.UMAP(random_state=99, n_components=2)),
            ]
        )
        mean_spliced = spliced_samples.mean(0)
        pipelines.fit(mean_spliced)
        adata.obsm["X_umap1"] = pipelines.transform(mean_spliced)
        adata.obsm["X_pyropca"] = pipelines.steps[0][1].transform(mean_spliced)
        scv.pp.neighbors(adata, use_rep="pyropca")

    assert len(posterior_samples["st"].shape) == 3
    adata.var["velocity_genes"] = True

    use_pca_projection = basis == "pca"
    embedded_key = f"velocity_pyro_{basis}"
    v_map_all = []
    for spliced_draw, velocity_draw in zip(spliced_samples, velocity_samples):
        adata.layers["spliced_pyro"] = spliced_draw
        adata.layers["velocity_pyro"] = velocity_draw
        if use_pca_projection:
            scv.pp.pca(adata)
            scv.tl.velocity_embedding(
                adata, vkey="velocity_pyro", basis="pca", direct_pca_projection=True
            )
        else:
            scv.tl.velocity_graph(
                adata, vkey="velocity_pyro", xkey="spliced_pyro", n_jobs=n_jobs
            )
            scv.tl.velocity_embedding(adata, vkey="velocity_pyro", basis=basis)
        v_map_all.append(adata.obsm[embedded_key])

    v_map_all = np.stack(v_map_all)
    embeds_radian = np.arctan2(v_map_all[:, :, 1], v_map_all[:, :, 0])

    from statsmodels.stats.multitest import multipletests

    rayleightest_pval = rayleightest(embeds_radian, axis=-2)
    corrected = multipletests(rayleightest_pval, method="fdr_bh")
    fdri = corrected[1]
    return v_map_all, embeds_radian, fdri
def get_posterior_sample_angle_uncertainty(posterior_angles):
    """Map the circular spread of posterior angles to an equivalent arc width.

    A lookup table is built holding the angular circular standard deviation
    of 100 evenly spaced angles spanning ``0 .. d + 1`` degrees, for every
    ``d`` in ``0 .. 359``. The circular standard deviation of
    ``posterior_angles`` along axis 0 is then matched to the closest table
    entry.

    Args:
        posterior_angles: Angles in degrees; axis 0 is reduced.

    Returns:
        Array of integers in ``0 .. 359``: for each remaining position, the
        table index ``d`` whose spread is closest.
    """
    from astropy import units as u
    from astropy.stats import circstd

    n_samples = 100
    method = "angular"
    x_values = np.arange(360)

    # Reference spread of a uniform fan of angles for each candidate width.
    reference = [
        circstd(np.linspace(0, width + 1, n_samples) * u.deg, method=method)
        for width in x_values
    ]
    y_values = np.array(reference).reshape(-1, 1)

    angle_std = circstd(posterior_angles * u.deg, method="angular", axis=0)
    closest = np.argmin(np.abs(y_values - angle_std), axis=0)
    return x_values[closest]
def plot_vector_field_uncertain(
    adata,
    embed_mean,
    embeds_radian_or_magnitude,
    fig=None,
    cbar=True,
    basis="umap",
    scale=0.002,
    cbar_pos=[0.22, 0.28, 0.5, 0.05],
    p_mass_min=3.5,
    only_grid=False,
    ax=None,
    autoscale=False,
    density=0.3,
    arrow_size=5,
    uncertain_measure="angle",
    cmap="winter",
    cmax=0.305,
):
    """Draw per-cell and grid-averaged uncertainty of the velocity field.

    The per-cell uncertainty is stored in ``adata.obs["uncertain"]``:

    * ``uncertain_measure == "angle"``: arc width returned by
      ``get_posterior_sample_angle_uncertainty`` for the angles converted
      from radians to degrees;
    * ``"base magnitude"``, ``"shared time"``, ``"PCA angle"``: the input is
      used as is;
    * anything else: standard deviation of the input over axis 0.

    Args:
        adata: AnnData with ``obsm["X_<basis>"]``; modified in place.
        embed_mean: Mean embedded velocity per cell.
        embeds_radian_or_magnitude: Posterior angles (radians) or another
            per-cell quantity, see above.
        fig: Figure used for new axes and the colour bar. Created or resolved
            from ``ax`` when omitted.
        cbar: Whether to add a horizontal colour bar below the grid panel.
        basis: Embedding key suffix.
        scale: Quiver scale.
        cbar_pos: Unused; kept for backward compatibility.
        p_mass_min, autoscale, density: Forwarded to ``project_grid_points``.
        only_grid: Skip the single-cell scatter panel.
        ax: A single axes, or a pair of axes (single-cell panel, grid panel).
        arrow_size: Passed to scVelo's ``default_arrow``.
        uncertain_measure: Name of the measure, also used in the titles.
        cmap: Colormap name.
        cmax: Upper colour limit for the quiver. The 95th percentile is used
            when ``None``.

    Returns:
        None.
    """
    print(adata.shape)
    print(embeds_radian_or_magnitude.shape)

    if uncertain_measure == "angle":
        adata.obs["uncertain"] = get_posterior_sample_angle_uncertainty(
            embeds_radian_or_magnitude / np.pi * 180
        )
    elif uncertain_measure in ["base magnitude", "shared time", "PCA angle"]:
        adata.obs["uncertain"] = embeds_radian_or_magnitude
    else:
        adata.obs["uncertain"] = embeds_radian_or_magnitude.std(axis=0)

    dot_size = 1
    if ax is None:
        if fig is None:
            fig = plt.figure()
        ax = fig.subplots(1, 2)

    # ``fig.subplots`` returns an ndarray, which the former ``list``-only
    # check never matched.
    if isinstance(ax, (list, tuple, np.ndarray)) and len(ax) == 2:
        if not only_grid:
            cell_uncertainty = adata.obs["uncertain"].values
            order = np.argsort(cell_uncertainty)
            ax[0].scatter(
                adata.obsm[f"X_{basis}"][:, 0][order],
                adata.obsm[f"X_{basis}"][:, 1][order],
                c=cell_uncertainty[order],
                cmap=cmap,
                norm=None,
                # Limits come from the per-cell values; the grid-level
                # ``uncertain`` used here before is not defined yet.
                vmin=np.percentile(cell_uncertainty, 5),
                vmax=np.percentile(cell_uncertainty, 95),
                s=dot_size,
                linewidth=1,
                edgecolors="face",
            )
            ax[0].axis("off")
            ax[0].set_title(
                f"Single-cell\n {uncertain_measure} uncertainty ", fontsize=7
            )
        ax = ax[1]

    X_grid, V_grid, uncertain = project_grid_points(
        adata.obsm[f"X_{basis}"],
        embed_mean,
        adata.obs["uncertain"].values,
        p_mass_min=p_mass_min,
        autoscale=autoscale,
        density=density,
    )

    hl, hw, hal = default_arrow(arrow_size)
    quiver_kwargs = {"angles": "xy", "scale_units": "xy"}
    quiver_kwargs.update({"width": 0.001, "headlength": hl / 2})
    quiver_kwargs.update({"headwidth": hw / 2, "headaxislength": hal / 2})
    quiver_kwargs.update({"linewidth": 1, "zorder": 3})

    ax.scatter(
        adata.obsm[f"X_{basis}"][:, 0],
        adata.obsm[f"X_{basis}"][:, 1],
        s=1,
        linewidth=0,
        color="gray",
        alpha=0.22,
    )
    im = ax.quiver(
        X_grid[:, 0],
        X_grid[:, 1],
        V_grid[:, 0],
        V_grid[:, 1],
        uncertain,
        norm=None,
        cmap=cmap,
        edgecolors="face",
        scale=scale,
        clim=(
            np.percentile(uncertain, 5),
            np.percentile(uncertain, 95) if cmax is None else cmax,
        ),
        **quiver_kwargs,
    )
    ax.set_title(f"Averaged\n {uncertain_measure} uncertainty ", fontsize=7)
    ax.axis("off")

    if cbar:
        if fig is None:
            fig = ax.get_figure()
        pos = ax.get_position()
        cbar_ax = fig.add_axes(
            [pos.x0 + 0.05, pos.y0 - 0.02, pos.width * 0.6, pos.height / 17]
        )
        colorbar = fig.colorbar(im, cax=cbar_ax, orientation="horizontal")
        colorbar.ax.tick_params(axis="x", labelsize=5.5)
        # NOTE(review): this sets a plain attribute on the axes and likely
        # has no effect on the ticks; ``set_colorbar`` assigns the locator on
        # the colour bar itself. Kept unchanged - confirm intent.
        colorbar.ax.locator = MaxNLocator(nbins=2, integer=True)
def compute_mean_vector_field(
    posterior_samples,
    adata,
    basis="umap",
    n_jobs=1,
    spliced="spliced_pyro",
    raw=False,
):
    """Compute the posterior-mean velocity and embed it with scVelo.

    The velocity ``beta * u - gamma * s`` is averaged over posterior samples
    (axis 0), stored in ``adata.layers["velocity_pyro"]`` and passed through
    ``scv.tl.velocity_graph`` and ``scv.tl.velocity_embedding``.

    Args:
        posterior_samples: Mapping with ``"ut"``, ``"st"`` and either
            ``beta``/``gamma`` or ``beta_k``/``gamma_k``; optional
            ``u_scale`` / ``s_scale``.
        adata: AnnData, modified in place (neighbours, layers, velocity graph
            and ``obsm["velocity_pyro_<basis>"]``).
        basis: Embedding to project the velocities onto.
        n_jobs: Parallel jobs for the velocity graph.
        spliced: Which expression to combine with the posterior rates:
            ``"spliced_pyro"`` (posterior ``ut``/``st``), ``"Ms"`` (layers
            ``Mu``/``Ms``) or ``"spliced"`` (layers ``unspliced``/``spliced``).
            For any other value no velocity is computed before embedding.
        raw: With ``spliced="spliced_pyro"``, normalise ``ut`` and ``st`` by
            their sum over the last axis first.

    Returns:
        None.
    """
    scv.pp.neighbors(adata, use_rep="pca")
    adata.var["velocity_genes"] = True

    if spliced == "spliced_pyro":
        ut = posterior_samples["ut"]
        st = posterior_samples["st"]
        if raw:
            ut = ut / ut.sum(axis=-1, keepdims=True)
            st = st / st.sum(axis=-1, keepdims=True)
        adata.layers["spliced_pyro"] = st.mean(0).squeeze()

        # TODO: two scales for the Normal distribution; only ``u_scale`` is
        # applied here (Poisson model).
        # In all branches the local ``st`` is used so that the ``raw``
        # normalisation applies to both terms; some branches previously read
        # ``posterior_samples["st"]`` directly.
        if "u_scale" in posterior_samples:
            velocity = (
                ut * posterior_samples["beta"] / posterior_samples["u_scale"]
                - st * posterior_samples["gamma"]
            ).mean(0)
        elif "beta_k" in posterior_samples:
            velocity = (
                (
                    ut * posterior_samples["beta_k"]
                    - st * posterior_samples["gamma_k"]
                )
                .mean(0)
                .squeeze()
            )
        else:
            velocity = (
                ut * posterior_samples["beta"] - st * posterior_samples["gamma"]
            ).mean(0)
        adata.layers["velocity_pyro"] = velocity
        scv.tl.velocity_graph(
            adata, vkey="velocity_pyro", xkey="spliced_pyro", n_jobs=n_jobs
        )
    elif spliced in ("Ms", "spliced"):
        unspliced_layer = "Mu" if spliced == "Ms" else "unspliced"
        ut = adata.layers[unspliced_layer]
        st = adata.layers[spliced]

        production = ut * posterior_samples["beta"]
        if ("u_scale" in posterior_samples) and ("s_scale" in posterior_samples):
            production = production / (
                posterior_samples["u_scale"] / posterior_samples["s_scale"]
            )
        # The degradation term uses the same layer as ``xkey``; the unscaled
        # branch previously mixed in the posterior ``st`` instead.
        adata.layers["velocity_pyro"] = (
            production - st * posterior_samples["gamma"]
        ).mean(0)
        scv.tl.velocity_graph(
            adata, vkey="velocity_pyro", xkey=spliced, n_jobs=n_jobs
        )

    scv.tl.velocity_embedding(adata, vkey="velocity_pyro", basis=basis)
def plot_mean_vector_field(
    posterior_samples,
    adata,
    ax,
    basis="umap",
    n_jobs=1,
    scale=0.2,
    density=0.4,
    spliced="spliced_pyro",
    raw=False,
):
    """Compute the posterior-mean vector field and draw it as a grid.

    Delegates the computation to ``compute_mean_vector_field`` and the
    drawing to ``scv.pl.velocity_embedding_grid``.

    Args:
        posterior_samples: Posterior dictionary forwarded to the computation.
        adata: AnnData, modified in place by the computation.
        ax: Axes to draw on.
        basis: Embedding to use.
        n_jobs: Parallel jobs for the velocity graph.
        scale: Arrow scale of the grid plot.
        density: Grid density of the grid plot.
        spliced: Expression source, see ``compute_mean_vector_field``.
        raw: See ``compute_mean_vector_field``.

    Returns:
        The embedded velocities ``adata.obsm["velocity_pyro_<basis>"]``.
    """
    compute_mean_vector_field(
        posterior_samples,
        adata,
        basis=basis,
        n_jobs=n_jobs,
        spliced=spliced,
        raw=raw,
    )

    grid_options = {
        "basis": basis,
        "vkey": "velocity_pyro",
        "linewidth": 1,
        "ax": ax,
        "show": False,
        "legend_loc": "on data",
        "density": density,
        "scale": scale,
        "arrow_size": 3,
    }
    scv.pl.velocity_embedding_grid(adata, **grid_options)

    embedded_key = f"velocity_pyro_{basis}"
    return adata.obsm[embedded_key]
# def project_grid_points(emb, velocity_emb, uncertain=None, p_mass_min=3.5, density=0.3):
def project_grid_points(
    emb, velocity_emb, uncertain=None, p_mass_min=1.0, density=0.3, autoscale=False
):
    """Average per-cell velocities (and uncertainties) onto a regular grid.

    A ``int(50 * density)``-point grid is laid over each of the first two
    embedding dimensions. Every grid point takes a Gaussian-weighted average
    over its nearest cells; grid points whose total weight does not exceed
    ``p_mass_min`` percent of the 99th weight percentile are dropped.

    Args:
        emb: Cell embedding; columns 0 and 1 are used.
        velocity_emb: Embedded velocities, either two-dimensional
            ``(cells, dims)`` or with one extra trailing axis.
        uncertain: Optional per-cell values to average the same way.
        p_mass_min: Threshold factor for the grid-point weight.
        density: Grid density factor.
        autoscale: Rescale the grid velocities with scVelo's
            ``quiver_autoscale``.

    Returns:
        ``(X_grid, V_grid)`` when ``uncertain`` is ``None``, otherwise
        ``(X_grid, V_grid, uncertain_grid)``, each restricted to the retained
        grid points.
    """
    from scipy.stats import norm as normal
    from sklearn.neighbors import NearestNeighbors

    grid_num = 50 * density
    smooth = 0.5

    grs = []
    for dim_i in range(2):
        m, M = np.min(emb[:, dim_i]), np.max(emb[:, dim_i])
        # Pad the range slightly; the second line deliberately reuses the
        # already-padded minimum, as in the original.
        m = m - 0.01 * np.abs(M - m)
        M = M + 0.01 * np.abs(M - m)
        grs.append(np.linspace(m, M, int(grid_num)))

    meshes_tuple = np.meshgrid(*grs)
    scale = np.mean([(g[1] - g[0]) for g in grs]) * smooth
    X_grid = np.vstack([i.flat for i in meshes_tuple]).T

    # At least one neighbour; embeddings with fewer than 50 cells used to
    # request zero neighbours.
    n_neighbors = max(1, int(emb.shape[0] / 50))
    print(n_neighbors)
    nn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1)
    nn.fit(emb)
    dists, neighs = nn.kneighbors(X_grid)

    weight = normal.pdf(x=dists, scale=scale)
    p_mass = weight.sum(1)
    denominator = np.maximum(1, p_mass)

    neighbour_velocity = velocity_emb[:, :2][neighs]
    if len(velocity_emb.shape) == 2:
        V_grid = (neighbour_velocity * weight[:, :, None]).sum(1) / denominator[
            :, None
        ]
    else:
        V_grid = (neighbour_velocity * weight[:, :, None, None]).sum(
            1
        ) / denominator[:, None, None]
    print(V_grid.shape)

    min_mass = p_mass_min * np.percentile(p_mass, 99) / 100
    if autoscale:
        # Imported lazily: only needed on this path.
        from scvelo.tools.velocity_embedding import quiver_autoscale

        V_grid /= 3 * quiver_autoscale(X_grid, V_grid)

    keep = p_mass > min_mass
    if uncertain is None:
        return X_grid[keep], V_grid[keep]

    uncertain = (uncertain[neighs] * weight).sum(1) / denominator
    return X_grid[keep], V_grid[keep], uncertain[keep]
def plot_arrow_examples(
    adata,
    v_maps,
    embeds_radian,
    embed_mean,
    ax=None,
    fig=None,
    cbar=True,
    basis="umap",
    n_sample=30,
    scale=0.0021,
    alpha=0.02,
    index=19,
    index2=0,
    scale2=0.04,
    num_certain=3,
    num_total=4,
    p_mass_min=1.0,
    density=0.3,
    arrow_size=4,
    customize_uncertain=None,
):
    """Draw posterior arrow fans at a few uncertain and a few certain grid points.

    Velocities are averaged onto a grid with ``project_grid_points`` and
    normalised to unit length. For the selected grid points, ``n_sample``
    posterior arrows are drawn semi-transparently together with their mean
    arrow, coloured by the grid-level uncertainty.

    Args:
        adata: AnnData with ``obsm["X_<basis>"]``.
        v_maps: Per-cell embedded velocities with posterior samples on the
            last axis (indexed as ``V_grid[point][component][sample]``).
        embeds_radian: Posterior angles in radians; used for the uncertainty
            unless ``customize_uncertain`` is given.
        embed_mean: Unused; kept for backward compatibility.
        ax: Axes to draw on; defaults to the current axes.
        fig, cbar: Unused; kept for backward compatibility.
        basis: Embedding key suffix.
        n_sample: Number of posterior arrows per grid point.
        scale: Quiver scale for the uncertain points.
        alpha: Opacity of the background cells.
        index: Offset into the grid points sorted by decreasing uncertainty.
        index2: Offset into the grid points sorted by increasing uncertainty.
        scale2: Quiver scale for the certain points.
        num_certain: Number of certain grid points.
        num_total: Total number of grid points; ``num_total - num_certain``
            uncertain ones are drawn.
        p_mass_min, density: Forwarded to ``project_grid_points``.
        arrow_size: Passed to scVelo's ``default_arrow``.
        customize_uncertain: Optional per-cell uncertainty overriding the
            angle-based one.

    Returns:
        None.
    """
    if ax is None:
        ax = plt.gca()

    X_grid, V_grid, uncertain = project_grid_points(
        adata.obsm[f"X_{basis}"],
        v_maps,
        get_posterior_sample_angle_uncertainty(embeds_radian / np.pi * 180)
        if customize_uncertain is None
        else customize_uncertain,
        p_mass_min=p_mass_min,
        density=density,
    )
    print(X_grid.shape, V_grid.shape, uncertain.shape)

    norm = Normalize()
    norm.autoscale(uncertain)
    # One colour per grid point, computed once instead of per arrow.
    point_colors = cm.inferno(norm(uncertain))

    hl, hw, hal = default_arrow(arrow_size)
    print(hl, hw, hal)
    # NOTE(review): the original first assigned {"angles": "xy",
    # "scale_units": "xy"} and immediately overwrote it with the dict below,
    # so those keys were never applied. Kept as effectively executed.
    quiver_kwargs = {"width": 0.002, "zorder": 0}
    quiver_kwargs.update({"headlength": hl / 2})
    quiver_kwargs.update({"headwidth": hw / 2, "headaxislength": hal / 2})

    ax.scatter(
        adata.obsm[f"X_{basis}"][:, 0],
        adata.obsm[f"X_{basis}"][:, 1],
        s=1,
        linewidth=0,
        color="gray",
        alpha=alpha,
    )

    # Normalise every arrow to unit length. The norm is computed once from
    # both original components; previously the y component was divided by
    # sqrt(Vy^2 + Vy^2) after x had already been overwritten.
    speed = np.sqrt(V_grid[:, 0] ** 2 + V_grid[:, 1] ** 2)
    V_grid[:, 0] = V_grid[:, 0] / speed
    V_grid[:, 1] = V_grid[:, 1] / speed

    def draw_arrow_group(point_indexes, arrow_scale):
        # Draw the posterior fan and the mean arrow for each selected point.
        for i in range(n_sample):
            for j in point_indexes:
                ax.quiver(
                    X_grid[j, 0],
                    X_grid[j, 1],
                    V_grid[j][0][i],
                    V_grid[j][1][i],
                    ec="face",
                    norm=Normalize(vmin=0, vmax=360),
                    scale=arrow_scale,
                    color=point_colors[j],
                    linewidth=0,
                    alpha=0.3,
                    **quiver_kwargs,
                )
                ax.quiver(
                    X_grid[j, 0],
                    X_grid[j, 1],
                    V_grid[j][0].mean(),
                    V_grid[j][1].mean(),
                    ec="black",
                    alpha=1,
                    norm=Normalize(vmin=0, vmax=360),
                    scale=arrow_scale,
                    linewidth=0,
                    color=point_colors[j],
                    **quiver_kwargs,
                )

    most_uncertain = np.argsort(uncertain)[::-1][
        index : (index + num_total - num_certain)
    ]
    draw_arrow_group(most_uncertain, scale)

    most_certain = np.argsort(uncertain)[index2 : (index2 + num_certain)]
    draw_arrow_group(most_certain, scale2)

    ax.axis("off")
def set_colorbar(
    smp,
    ax,
    orientation="vertical",
    labelsize=None,
    fig=None,
    position="right",
    rainbow=False,
):
    """Attach a compact, opaque colour bar for ``smp`` to ``ax``.

    Args:
        smp: Mappable (e.g. the return value of ``ax.scatter``).
        ax: Axes the colour bar belongs to.
        orientation: Colour bar orientation.
        labelsize: Tick label size.
        fig: Figure (or sub-figure) used to create the colour bar; resolved
            from ``ax`` when omitted.
        position: ``"right"`` with ``rainbow=False`` places a small inset bar
            in the lower-right corner; otherwise an axes is appended on the
            given side.
        rainbow: Force the appended-axes layout.

    Returns:
        None.
    """
    from matplotlib.ticker import MaxNLocator
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    if fig is None:
        # Previously ``fig=None`` (the default) crashed on ``fig.colorbar``.
        fig = ax.get_figure()

    if position == "right" and (not rainbow):
        cax = inset_axes(ax, width="2%", height="30%", loc=4, borderpad=0)
        cb = fig.colorbar(smp, orientation=orientation, cax=cax)
    else:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes(position, size="8%", pad=0.08)
        cb = fig.colorbar(smp, cax=cax, orientation=orientation, shrink=0.4)

    cb.ax.tick_params(labelsize=labelsize)

    # Make the bar opaque even if the mappable is translucent.
    cb.set_alpha(1)
    if hasattr(cb, "draw_all"):
        # Older Matplotlib: redraw so that the new alpha takes effect.
        cb.draw_all()
    elif cb.solids is not None:
        # ``Colorbar.draw_all`` no longer exists in recent Matplotlib; apply
        # the alpha to the drawn patches directly.
        cb.solids.set_alpha(1)

    cb.locator = MaxNLocator(nbins=2, integer=True)
    if position == "left":
        cb.ax.yaxis.set_ticks_position("left")
    cb.update_ticks()
def us_rainbowplot(
    genes: pd.Index,
    adata: AnnData,
    posterior_samples: Dict[str, ndarray],
    data: List[str] = ["st", "ut"],
    cell_state: str = "clusters",
) -> Figure:
    """Plot, per gene, the phase portrait and spliced expression over time.

    Each gene gets one row: the left panel shows unspliced against spliced,
    the right panel spliced against the posterior-mean cell time; points are
    coloured by ``cell_state``.

    Args:
        genes: Genes to plot (labels of ``adata.var_names``).
        adata: AnnData providing ``var_names`` and ``obs[cell_state]``.
        posterior_samples: Mapping with ``"cell_time"`` and either the keys in
            ``data`` (averaged over axis 0) or precomputed ``"st_mean"`` /
            ``"ut_mean"``.
        data: Keys of the spliced and unspliced posterior samples, in that
            order. Not modified.
        cell_state: ``adata.obs`` column used for colouring.

    Returns:
        The created figure.
    """
    import matplotlib.lines as mlines

    # ``squeeze=False`` keeps a 2-D axes array so a single gene works too.
    fig, ax = plt.subplots(len(genes), 2, squeeze=False)
    fig.set_size_inches(7, 14)

    if data[0] in posterior_samples:
        pos_s = posterior_samples[data[0]].mean(0).squeeze()
        pos_u = posterior_samples[data[1]].mean(0).squeeze()
    else:
        pos_u = posterior_samples["ut_mean"]
        pos_s = posterior_samples["st_mean"]

    n = 0
    for gene in genes:
        (index,) = np.where(adata.var_names == gene)
        ax1 = ax[n, 1]
        if n == 0:
            ax1.set_title("Rainbow plot")
        ress = pd.DataFrame(
            {
                "cell_time": posterior_samples["cell_time"].mean(0).squeeze(),
                "cell_type": adata.obs[cell_state].values,
                "spliced": pos_s[:, index].squeeze(),
                "unspliced": pos_u[:, index].squeeze(),
            }
        )
        # The original had two branches (``n == 2`` / otherwise) issuing the
        # same call; they are merged here.
        sns.scatterplot(
            x="cell_time",
            y="spliced",
            data=ress,
            alpha=0.4,
            linewidth=0,
            edgecolor="none",
            hue="cell_type",
            ax=ax1,
            marker="o",
            legend="brief",
            palette="bright",
            s=10,
        )

        ax2 = ax[n, 0]
        ax2.set_title(gene)
        ax2.set_ylabel("")
        ax2.set_xlabel("")
        sns.scatterplot(
            x="spliced",
            y="unspliced",
            data=ress,
            alpha=0.4,
            s=25,
            edgecolor="none",
            hue="cell_type",
            ax=ax2,
            legend=False,
            marker="*",
            palette="bright",
        )

        if n == 3:
            blue_star = mlines.Line2D(
                [],
                [],
                color="black",
                marker="o",
                linestyle="None",
                markersize=5,
                label="Spliced",
            )
            red_square = mlines.Line2D(
                [],
                [],
                color="black",
                marker="+",
                linestyle="None",
                markersize=5,
                label="Unspliced",
            )
            ax1.legend(handles=[blue_star, red_square], bbox_to_anchor=[2, -0.03])
        ax1.set_xlabel("")
        n += 1

    # NOTE(review): nesting reconstructed from a flattened listing; the
    # statements below are assumed to run once, on the last row.
    ax1.legend(bbox_to_anchor=[2, 0.1])
    ax1.tick_params(labelbottom=True)
    ax2.set_xlabel("spliced")
    ax2.set_title(gene)

    plt.subplots_adjust(hspace=0.8, wspace=0.6, left=0.1, right=0.91)
    return fig
def rainbowplot(
    volcano_data,
    adata,
    posterior_samples,
    fig=None,
    genes=None,
    data=["st", "ut"],
    cell_state="clusters",
    basis="umap",
    num_genes=5,
    add_line=True,
    negative=False,
    scvelo_colors=False,
) -> Figure:
    """Draw phase portraits, rainbow plots and embeddings for selected genes.

    For every gene one row with three panels is produced: unspliced against
    spliced ("Phase portrait"), spliced against normalised posterior mean
    cell time ("Rainbow plot"), and the ``basis`` embedding coloured by the
    spliced values ("Denoised spliced").

    Sets the global Matplotlib ``font.size`` rcParam to 7.

    Args:
        volcano_data: Table indexed by gene with ``mean_mae`` and
            ``time_correlation`` columns; only used when ``genes`` is None.
        adata: Annotated data providing ``var_names``, ``obs[cell_state]``
            and ``obsm[f"X_{basis}"]``.
        posterior_samples: Posterior samples. Must contain ``"cell_time"``
            and either both keys in ``data`` or ``"st_mean"``/``"ut_mean"``.
        fig: Figure to draw into; a 5.5 x 4.5 inch figure is created when
            ``None``.
        genes: Genes to plot. When ``None`` the 300 rows with the largest
            ``mean_mae`` are sorted by ``time_correlation`` and the first
            ``num_genes`` are used.
        data: Keys of the spliced and unspliced samples (read only).
        cell_state: Categorical column of ``adata.obs`` used for colouring.
        basis: Name of the embedding stored under ``obsm[f"X_{basis}"]``.
        num_genes: Number of genes selected when ``genes`` is ``None``.
        add_line: Draw a vertical line from 0 to the spliced value for each
            cell in the rainbow panel.
        negative: Passed as ``ascending`` when sorting by
            ``time_correlation``.
        scvelo_colors: Use the colours stored by scvelo in
            ``adata.uns[f"{cell_state}_colors"]`` instead of the seaborn
            ``"deep"`` palette.

    Returns:
        The figure that was drawn into.
    """
    matplotlib.rcParams.update({"font.size": 7})

    if genes is None:
        genes = (
            volcano_data.sort_values("mean_mae", ascending=False)
            .head(300)
            .sort_values("time_correlation", ascending=negative)
            .head(num_genes)
            .index
        )
    if fig is None:
        fig = plt.figure(figsize=(5.5, 4.5))

    clusters = adata.obs.loc[:, cell_state]
    if scvelo_colors:
        scv.pl.scatter(
            adata,
            basis=basis,
            fontsize=7,
            legend_loc="on data",
            legend_fontsize=7,
            color=cell_state,
            show=False,
        )
        # Use the colours of the requested ``cell_state`` column; the previous
        # hard-coded ``state_info`` lookup produced a palette for the wrong
        # categories whenever another column was plotted.
        colors = dict(
            zip(clusters.cat.categories, adata.uns[f"{cell_state}_colors"])
        )
    else:
        colors = dict(
            zip(
                clusters.cat.categories,
                sns.color_palette("deep", clusters.cat.categories.shape[0]),
            )
        )

    subfigs = fig.subfigures(1, 2, wspace=0.0, width_ratios=[3, 1.5])
    # squeeze=False keeps both grids two-dimensional for a single gene.
    ax = subfigs[0].subplots(len(genes), 2, squeeze=False)
    ax_fig2 = subfigs[1].subplots(len(genes), 1, squeeze=False)

    if (data[0] in posterior_samples) and (data[1] in posterior_samples):
        st = posterior_samples[data[0]].mean(0).squeeze()
        ut = posterior_samples[data[1]].mean(0).squeeze()
    else:
        st = posterior_samples["st_mean"]
        ut = posterior_samples["ut_mean"]

    # Gene-independent quantities are computed once.
    pos_mean_time = posterior_samples["cell_time"].mean(0).flatten()
    cell_types = adata.obs[cell_state].values
    embedding = adata.obsm[f"X_{basis}"]
    last_row = len(genes) - 1

    for n, gene in enumerate(genes):
        (index,) = np.where(adata.var_names == gene)
        ax1 = ax[n, 1]
        ax2 = ax[n, 0]
        ax3 = ax_fig2[n, 0]
        if n == 0:
            ax1.set_title("Rainbow plot", fontsize=7)
            ax2.set_title("Phase portrait", fontsize=7)

        spliced = st[:, index].flatten()
        unspliced = ut[:, index].flatten()

        # --- rainbow panel: spliced against normalised shared time ---
        ress = pd.DataFrame(
            {
                "cell_time": pos_mean_time / pos_mean_time.max(),
                "cell_type": cell_types,
                "spliced": spliced,
                "unspliced": unspliced,
            }
        )
        sns.scatterplot(
            x="cell_time",
            y="spliced",
            data=ress,
            alpha=0.1,
            linewidth=0,
            edgecolor="none",
            hue="cell_type",
            palette=colors,
            ax=ax1,
            marker="o",
            legend=False,
            s=5,
        )
        if add_line:
            # One vectorised call instead of one ``vlines`` call per cell.
            ax1.vlines(
                x=ress["cell_time"].to_numpy(),
                ymin=0,
                ymax=ress["spliced"].to_numpy(),
                colors=[colors[cell_type] for cell_type in ress["cell_type"]],
                alpha=0.1,
            )
        if n == last_row:
            ax1.set_xlabel("shared time", fontsize=7)
        else:
            ax1.set_xlabel("")
        ax1.set_ylabel("")
        t = [0, round(ress["cell_time"].max(), 5)]
        t_label = ["0", "%.1E" % ress["cell_time"].max()]
        ax1.set_xticks(t, t_label, fontsize=7)
        t = [0, round(ress["spliced"].max(), 5)]
        t_label = ["0", "%.1E" % ress["spliced"].max()]
        ax1.set_yticks(t, t_label, fontsize=7)

        # --- phase portrait: unspliced against spliced ---
        ress = pd.DataFrame(
            {
                "cell_type": cell_types,
                "unspliced": unspliced,
                "spliced": spliced,
            }
        )
        sns.scatterplot(
            x="spliced",
            y="unspliced",
            data=ress,
            alpha=0.4,
            linewidth=0,
            edgecolor="none",
            hue="cell_type",
            palette=colors,
            ax=ax2,
            marker="o",
            legend=False,
            s=3,
        )
        ax2.set_xlabel("")
        ax2.set_ylabel(gene, fontsize=7, rotation=0, labelpad=23)
        if n == last_row:
            ax2.set_xlabel("spliced", fontsize=7)
        t = [0, round(ress["unspliced"].max(), 5)]
        t_label = ["0", "%.1E" % ress["unspliced"].max()]
        ax2.set_yticks(t, t_label, fontsize=7)
        t = [0, round(ress["spliced"].max(), 5)]
        t_label = ["0", "%.1E" % ress["spliced"].max()]
        ax2.set_xticks(t, t_label, fontsize=7)

        # --- embedding coloured by the spliced values ---
        im = ax3.scatter(
            embedding[:, 0],
            embedding[:, 1],
            s=3,
            c=spliced,
            cmap="RdBu_r",
        )
        set_colorbar(im, ax3, labelsize=5, fig=subfigs[1], rainbow=True)
        ax3.axis("off")
        if n == 0:
            ax3.set_title("Denoised spliced", fontsize=7)

    sns.despine()
    subfigs[0].subplots_adjust(
        hspace=0.8, wspace=1.4, left=0.32, right=0.94, top=0.92, bottom=0.12
    )
    subfigs[1].subplots_adjust(
        hspace=0.8, wspace=0.4, left=0.2, right=0.7, top=0.92, bottom=0.08
    )
    subfigs[0].text(
        -0.025, 0.58, "unspliced expression", size=7, rotation="vertical", va="center"
    )
    subfigs[0].text(
        0.552, 0.58, "spliced expression", size=7, rotation="vertical", va="center"
    )
    return fig
def plot_state_uncertainty(
    posterior_samples,
    adata,
    kde=True,
    data="denoised",
    top_percentile=0.9,
    ax=None,
    basis="umap",
):
    """Plot per-cell state uncertainty on an embedding.

    The uncertainty is the square root of the squared deviations of the
    spliced and unspliced samples from their mean over axis 0, summed over
    the last axis, then averaged over axis 0. It is stored in
    ``adata.obs["state_uncertain"]`` and shown with ``scv.pl.scatter``.

    Args:
        posterior_samples: Mapping with ``"st"``/``"ut"`` (used when
            ``data == "denoised"``) or ``"s"``/``"u"`` (any other value).
        adata: Annotated data; ``obs["state_uncertain"]`` is overwritten.
        kde: Overlay a kernel density estimate of the most uncertain cells.
        data: ``"denoised"`` selects ``"st"``/``"ut"``; anything else
            selects ``"s"``/``"u"``.
        top_percentile: Quantile of ``state_uncertain`` above which cells are
            included in the density estimate.
        ax: Axes forwarded to ``scv.pl.scatter``.
        basis: Embedding name; coordinates come from ``obsm[f"X_{basis}"]``.

    Returns:
        ``(select, ax)`` where ``select`` is the boolean mask of cells above
        the quantile (``None`` when ``kde`` is false) and ``ax`` is whatever
        ``scv.pl.scatter`` returned.
    """
    spliced_key, unspliced_key = ("st", "ut") if data == "denoised" else ("s", "u")
    spliced = posterior_samples[spliced_key]
    unspliced = posterior_samples[unspliced_key]
    adata.obs["state_uncertain"] = np.sqrt(
        (
            (spliced - spliced.mean(0)) ** 2
            + (unspliced - unspliced.mean(0)) ** 2
        ).sum(-1)
    ).mean(0)

    ax = scv.pl.scatter(
        adata,
        basis=basis,
        color="state_uncertain",
        cmap="RdBu_r",
        ax=ax,
        show=False,
        colorbar=True,
        fontsize=7,
    )

    if kde:
        select = adata.obs["state_uncertain"] > np.quantile(
            adata.obs["state_uncertain"], top_percentile
        )
        embedding = adata.obsm[f"X_{basis}"]
        # x/y are passed by keyword: recent seaborn releases no longer accept
        # the coordinate vectors positionally.
        sns.kdeplot(
            x=embedding[:, 0][select],
            y=embedding[:, 1][select],
            ax=ax,
            levels=3,
            fill=False,
        )
    else:
        select = None
    return select, ax
import anndata from scipy.sparse import issparse
def get_clone_trajectory(
    adata, average_start_point=True, global_traj=True, times=[2, 4, 6], clone_num=None
):
    """Build clone centroids and clone displacement vectors on an embedding.

    ``adata.obs["clones"]`` is reset to 0 and a scratch column
    ``adata.obs["clonei"]`` marks the cells of the clone being processed.

    Behaviour by mode (for the branch taken when ``adata.obs`` has no
    ``"noWell"`` column):

    * ``average_start_point`` false: for every cell of a clone at time 2
      (resp. 4) the vector from the cell to the mean ``X_umap`` position of
      the same clone at time 4 (resp. 6) is written to
      ``adata.obsm["velocity_umap"]``. No centroid is created.
    * ``average_start_point`` and ``global_traj`` true: for every clone the
      cells at each entry of ``times`` (matched on ``obs.time_info``) are
      averaged into one centroid observation per populated time point, with
      ``obsm["X_emb"]`` set to the mean embedding and
      ``obsm["clone_vector_emb"]`` to the difference between successive
      centroids (zeros for the last one).
    * ``average_start_point`` true, ``global_traj`` false: centroids for the
      fixed time points 2, 4 and 6 are built from ``X`` and the
      ``spliced``/``unspliced`` layers.

    Args:
        adata: Annotated data. Modified in place (``obs["clones"]``,
            ``obs["clonei"]`` and, depending on the mode, ``obsm`` entries).
        average_start_point: Select centroid mode (true) or per-cell vector
            mode (false).
        global_traj: Select the ``times``/``X_emb`` based centroid variant.
        times: Time points used by the ``global_traj`` variant. Only read.
        clone_num: Number of clone columns of ``obsm["X_clone"]`` to process;
            defaults to all columns.

    Returns:
        ``adata`` concatenated (outer join) with all centroids, or ``adata``
        itself when no centroid was produced. The ``"noWell"`` branch returns
        a pair ``(adata, clones)`` instead.
    """
    if not average_start_point:
        adata.obsm["clone_vector_emb"] = np.zeros((adata.shape[0], 2))

    adatas = []
    clones = []
    centroids = []
    cen_clones = []
    print(adata.shape)
    adata.obs["clones"] = 0

    if "noWell" in adata.obs.columns:
        # NOTE(review): ``clone_adata`` is not defined in this function; this
        # branch only works if a module-level ``clone_adata`` exists. The
        # "noWell" column name also looks like a deliberate switch-off of
        # this code path — confirm before relying on it.
        for w in adata.obs.Well.unique():
            adata_w = adata[adata.obs.Well == w]
            clone_adata_w = clone_adata[clone_adata.obs.Well == w]
            for j in range(clone_adata_w.shape[1]):
                adata_w.obs["clonei"] = 0
                # belongs to same clone
                adata_w.obs.loc[
                    clone_adata_w[:, j].X.toarray()[:, 0] >= 1, "clonei"
                ] = 1

                if not average_start_point:
                    for i in np.where(
                        (adata_w.obs.time == 2) & (adata_w.obs.clonei == 1)
                    )[0]:
                        next_time = np.where(
                            (adata_w.obs.time == 4) & (adata_w.obs.clonei == 1)
                        )[0]
                        adata_w.obsm["velocity_umap"][i] = (
                            adata_w.obsm["X_umap"][next_time].mean(axis=0)
                            - adata_w.obsm["X_umap"][i]
                        )
                    for i in np.where(
                        (adata_w.obs.time == 4) & (adata_w.obs.clonei == 1)
                    )[0]:
                        next_time = np.where(
                            (adata_w.obs.time == 6) & (adata_w.obs.clonei == 1)
                        )[0]
                        adata_w.obsm["velocity_umap"][i] = (
                            adata_w.obsm["X_umap"][next_time].mean(axis=0)
                            - adata_w.obsm["X_umap"][i]
                        )
                else:
                    time2 = np.where(
                        (adata_w.obs.time == 2) & (adata_w.obs.clonei == 1)
                    )[0]
                    time4 = np.where(
                        (adata_w.obs.time == 4) & (adata_w.obs.clonei == 1)
                    )[0]
                    time6 = np.where(
                        (adata_w.obs.time == 6) & (adata_w.obs.clonei == 1)
                    )[0]
                    # Skip clones with no cells at all ...
                    if (
                        time2.shape[0] == 0
                        and time4.shape[0] == 0
                        and time6.shape[0] == 0
                    ):
                        continue
                    # ... and clones that miss the middle time point.
                    if (
                        time2.shape[0] > 0
                        and time4.shape[0] == 0
                        and time6.shape[0] > 0
                    ):
                        continue

                    adata_new = anndata.AnnData(
                        np.vstack(
                            [
                                adata_w[time2].X.toarray().mean(axis=0),
                                adata_w[time4].X.toarray().mean(axis=0),
                                adata_w[time6].X.toarray().mean(axis=0),
                            ]
                        ),
                        layers={
                            "spliced": np.vstack(
                                [
                                    adata_w[time2]
                                    .layers["spliced"]
                                    .toarray()
                                    .mean(axis=0),
                                    adata_w[time4]
                                    .layers["spliced"]
                                    .toarray()
                                    .mean(axis=0),
                                    adata_w[time6]
                                    .layers["spliced"]
                                    .toarray()
                                    .mean(axis=0),
                                ]
                            ),
                            "unspliced": np.vstack(
                                [
                                    adata_w[time2]
                                    .layers["unspliced"]
                                    .toarray()
                                    .mean(axis=0),
                                    adata_w[time4]
                                    .layers["unspliced"]
                                    .toarray()
                                    .mean(axis=0),
                                    adata_w[time6]
                                    .layers["unspliced"]
                                    .toarray()
                                    .mean(axis=0),
                                ]
                            ),
                        },
                        var=adata_w.var,
                    )

                    adata_new.obs.loc[:, "time"] = [2, 4, 6]
                    adata_new.obs.loc[:, "Cell type annotation"] = "Centroid"
                    print(adata_w[time6].obs.clonetype.unique())
                    print(adata_w[time6].obs)

                    # use cell fate from last time point
                    adata_new.obs.loc[:, "clonetype"] = adata_w[
                        time6
                    ].obs.clonetype.unique()
                    adata_new.obs.loc[:, "clones"] = int(j)
                    if "Well" in adata_w[time6].obs.columns:
                        adata_new.obs.loc[:, "Well"] = adata_w[time6].obs.Well.unique()

                    adata_new.obsm["X_umap"] = np.vstack(
                        [
                            adata_w[time2].obsm["X_umap"].mean(axis=0),
                            adata_w[time4].obsm["X_umap"].mean(axis=0),
                            adata_w[time6].obsm["X_umap"].mean(axis=0),
                        ]
                    )
                    adata_new.obsm["velocity_umap"] = np.vstack(
                        [
                            adata_w.obsm["X_umap"][time4].mean(axis=0)
                            - adata_w.obsm["X_umap"][time2].mean(axis=0),
                            adata_w.obsm["X_umap"][time6].mean(axis=0)
                            - adata_w.obsm["X_umap"][time4].mean(axis=0),
                            np.zeros(2),
                        ]
                    )
                    centroids.append(adata_new)

                    clone_new = anndata.AnnData(
                        np.vstack(
                            [
                                clone_adata_w[time2].X.toarray().mean(axis=0),
                                clone_adata_w[time4].X.toarray().mean(axis=0),
                                clone_adata_w[time6].X.toarray().mean(axis=0),
                            ]
                        ),
                        obs=adata_new.obs,
                    )
                    clone_new.var_names = clone_adata.var_names
                    clone_new.var = clone_adata.var
                    cen_clones.append(clone_new)

            adata_new = adata_w.concatenate(
                centroids[0].concatenate(centroids[1:]), join="outer"
            )
            clone_new = clone_adata_w.concatenate(
                cen_clones[0].concatenate(cen_clones[1:]), join="outer"
            )
            adatas.append(adata_new)
            clones.append(clone_new)
        return adatas[0].concatenate(adatas[1]), clones[0].concatenate(clones[1])

    if clone_num is None:
        clone_num = adata.obsm["X_clone"].shape[1]

    for j in range(clone_num):
        print(j)
        # Flag the cells that belong to clone ``j``.
        adata.obs["clonei"] = 0
        if issparse(adata.obsm["X_clone"]):
            adata.obs.loc[adata.obsm["X_clone"].toarray()[:, j] >= 1, "clonei"] = 1
        else:
            adata.obs.loc[adata.obsm["X_clone"][:, j] >= 1, "clonei"] = 1

        if not average_start_point:
            for i in np.where((adata.obs.time == 2) & (adata.obs.clonei == 1))[0]:
                next_time = np.where(
                    (adata.obs.time == 4) & (adata.obs.clonei == 1)
                )[0]
                adata.obsm["velocity_umap"][i] = (
                    adata.obsm["X_umap"][next_time].mean(axis=0)
                    - adata.obsm["X_umap"][i]
                )
            for i in np.where((adata.obs.time == 4) & (adata.obs.clonei == 1))[0]:
                next_time = np.where(
                    (adata.obs.time == 6) & (adata.obs.clonei == 1)
                )[0]
                adata.obsm["velocity_umap"][i] = (
                    adata.obsm["X_umap"][next_time].mean(axis=0)
                    - adata.obsm["X_umap"][i]
                )
        else:
            if global_traj:
                times_index = []
                for t in times:
                    times_index.append(
                        np.where(
                            (adata.obs.time_info == t) & (adata.obs.clonei == 1)
                        )[0]
                    )

                consecutive_flag = np.array(
                    [int(time.shape[0] > 0) for time in times_index]
                )
                consecutive = np.diff(consecutive_flag)
                # Must be consecutive time points: at least two populated
                # time points and at least one adjacent pair with equal flags.
                if np.sum(consecutive_flag == 1) >= 2 and np.any(
                    consecutive == 0
                ):
                    adata_new = anndata.AnnData(
                        np.vstack(
                            [
                                np.array(adata[time].X.mean(axis=0)).squeeze()
                                for time in times_index
                                if time.shape[0] > 0
                            ]
                        ),
                        var=adata.var,
                    )
                    adata.obs.iloc[
                        np.hstack(
                            [time for time in times_index if time.shape[0] > 0]
                        ),
                        adata.obs.columns.get_loc("clones"),
                    ] = int(j)
                    # Label each centroid with the requested time point it was
                    # built from (previously zipped against a fixed [2, 4, 6],
                    # which ignored a custom ``times``).
                    adata_new.obs.loc[:, "time"] = [
                        t for t, time in zip(times, times_index) if time.shape[0] > 0
                    ]
                    adata_new.obs.loc[:, "clones"] = int(j)
                    adata_new.obs.loc[:, "state_info"] = "Centroid"
                    adata_new.obsm["X_emb"] = np.vstack(
                        [
                            adata[time].obsm["X_emb"].mean(axis=0)
                            for time in times_index
                            if time.shape[0] > 0
                        ]
                    )
                    # Displacement from each centroid to the next one; the
                    # last centroid has no successor and gets a zero vector.
                    adata_new.obsm["clone_vector_emb"] = np.vstack(
                        [
                            adata_new.obsm["X_emb"][i + 1] - adata_new.obsm["X_emb"][i]
                            for i in range(adata_new.obsm["X_emb"].shape[0] - 1)
                        ]
                        + [np.zeros(2)]
                    )
                else:
                    continue
            else:
                # NOTE(review): ``clone_adata`` used below is not defined in
                # this function; this variant needs a module-level
                # ``clone_adata`` to run.
                # First time point is 2 (the original compared against an
                # unbound ``t``).
                time2 = np.where((adata.obs.time == 2) & (adata.obs.clonei == 1))[0]
                time4 = np.where((adata.obs.time == 4) & (adata.obs.clonei == 1))[0]
                time6 = np.where((adata.obs.time == 6) & (adata.obs.clonei == 1))[0]
                adata_new = anndata.AnnData(
                    np.vstack(
                        [
                            adata[time2].X.toarray().mean(axis=0),
                            adata[time4].X.toarray().mean(axis=0),
                            adata[time6].X.toarray().mean(axis=0),
                        ]
                    ),
                    layers={
                        "spliced": np.vstack(
                            [
                                adata[time2]
                                .layers["spliced"]
                                .toarray()
                                .mean(axis=0),
                                adata[time4]
                                .layers["spliced"]
                                .toarray()
                                .mean(axis=0),
                                adata[time6]
                                .layers["spliced"]
                                .toarray()
                                .mean(axis=0),
                            ]
                        ),
                        "unspliced": np.vstack(
                            [
                                adata[time2]
                                .layers["unspliced"]
                                .toarray()
                                .mean(axis=0),
                                adata[time4]
                                .layers["unspliced"]
                                .toarray()
                                .mean(axis=0),
                                adata[time6]
                                .layers["unspliced"]
                                .toarray()
                                .mean(axis=0),
                            ]
                        ),
                    },
                    var=adata.var,
                )

                print(adata_new.X.sum(axis=1))
                adata_new.obs.loc[:, "time"] = [2, 4, 6]
                adata_new.obs.loc[:, "Cell type annotation"] = "Centroid"
                if not global_traj:
                    # use cell fate from last time point
                    adata_new.obs.loc[:, "clonetype"] = adata[
                        time6
                    ].obs.clonetype.unique()
                adata_new.obs.loc[:, "clones"] = j

                if "noWell" in adata[time6].obs.columns:
                    adata_new.obs.loc[:, "Well"] = adata[time6].obs.Well.unique()

                adata_new.obsm["X_umap"] = np.vstack(
                    [
                        adata[time2].obsm["X_umap"].mean(axis=0),
                        adata[time4].obsm["X_umap"].mean(axis=0),
                        adata[time6].obsm["X_umap"].mean(axis=0),
                    ]
                )
                adata_new.obsm["velocity_umap"] = np.vstack(
                    [
                        adata.obsm["X_umap"][time4].mean(axis=0)
                        - adata.obsm["X_umap"][time2].mean(axis=0),
                        adata.obsm["X_umap"][time6].mean(axis=0)
                        - adata.obsm["X_umap"][time4].mean(axis=0),
                        np.zeros(2),
                    ]
                )

                clone_new = anndata.AnnData(
                    np.vstack(
                        [
                            clone_adata[time2].X.toarray().mean(axis=0),
                            clone_adata[time4].X.toarray().mean(axis=0),
                            clone_adata[time6].X.toarray().mean(axis=0),
                        ]
                    ),
                    obs=adata_new.obs,
                )
                clone_new.var_names = clone_adata.var_names
                clone_new.var = clone_adata.var
                cen_clones.append(clone_new)
            centroids.append(adata_new)

    print(adata.shape)
    print(len(centroids))
    # Nothing to append (per-cell vector mode, or no clone qualified): return
    # the annotated input instead of failing on ``centroids[0]``.
    if not centroids:
        return adata
    adata_new = adata.concatenate(
        centroids[0].concatenate(centroids[1:]), join="outer"
    )
    return adata_new
def align_trajectory_diff(
    adatas,
    velocity_embeds,
    density=0.3,
    smooth=0.5,
    input_grid=None,
    input_scale=None,
    min_mass=1.0,
    embed="umap",
    autoscale=False,
    length_cutoff=10,
):
    """Project several embedded vector fields onto one shared grid.

    A regular 2-D grid is built over the stacked ``X_{embed}`` coordinates of
    all ``adatas`` (or taken from ``input_grid``). For each dataset the
    vectors of the nearest cells are averaged at every grid point with
    Gaussian weights of width ``scale``. Grid points are kept only if their
    weight mass exceeds ``min_mass`` in every dataset and the vector in
    result columns 2:4 is longer than ``length_cutoff``.

    Args:
        adatas: Sequence of annotated data objects exposing ``shape`` and
            ``obsm[f"X_{embed}"]``.
        velocity_embeds: One array of per-cell vectors per entry of
            ``adatas``, indexable by cell.
        density: Grid resolution; each axis gets ``int(50 * density)`` points.
        smooth: Multiplier applied to the mean grid spacing to obtain the
            Gaussian scale.
        input_grid: Precomputed grid points; must be given together with
            ``input_scale``.
        input_scale: Precomputed Gaussian scale; must be given together with
            ``input_grid``.
        min_mass: Mass threshold. When the grid is built here it is rescaled
            by the 99th percentile of all masses divided by 100.
        embed: Embedding name used to look up ``obsm[f"X_{embed}"]``.
        autoscale: Divide each grid field by three times scvelo's
            ``quiver_autoscale`` factor.
        length_cutoff: Minimum Euclidean length of result columns 2:4.

    Returns:
        Array whose first columns are the grid coordinates followed by the
        grid vectors of each dataset, restricted to the retained grid points.

    Raises:
        ValueError: If only one of ``input_grid`` / ``input_scale`` is given.
    """
    from functools import reduce

    from scipy.stats import norm as normal
    from sklearn.neighbors import NearestNeighbors

    # A grid without its scale (or vice versa) cannot be used and previously
    # failed later with an unrelated error.
    if (input_grid is None) != (input_scale is None):
        raise ValueError(
            "input_grid and input_scale must be provided together or both omitted"
        )
    build_grid = input_grid is None

    if build_grid:
        grs = []
        # align embedding points into shared grid across adata
        X_emb = np.vstack([a.obsm[f"X_{embed}"] for a in adatas])
        for dim_i in range(2):
            m, M = np.min(X_emb[:, dim_i]), np.max(X_emb[:, dim_i])
            # NOTE: the upper padding is computed from the already lowered
            # ``m``; kept as is so existing grids do not shift.
            m = m - 0.01 * np.abs(M - m)
            M = M + 0.01 * np.abs(M - m)
            grs.append(np.linspace(m, M, int(50 * density)))

        meshes_tuple = np.meshgrid(*grs)
        scale = np.mean([(g[1] - g[0]) for g in grs]) * smooth
        X_grid = np.vstack([i.flat for i in meshes_tuple]).T
    else:
        scale = input_scale
        X_grid = input_grid

    # At least one neighbour, and never more than the smallest dataset holds;
    # otherwise the neighbour search raises for small inputs.
    cell_counts = [a.shape[0] for a in adatas]
    n_neighbors = min(max(1, int(max(cell_counts) / 50)), min(cell_counts))

    if autoscale:
        # Imported lazily: only needed for the optional rescaling.
        from scvelo.tools.velocity_embedding import quiver_autoscale

    results = [X_grid]
    p_mass_list = []
    for adata, velocity_embed in zip(adatas, velocity_embeds):
        nn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1)
        nn.fit(adata.obsm[f"X_{embed}"])
        dists, neighs = nn.kneighbors(X_grid)
        weight = normal.pdf(x=dists, scale=scale)
        # how many cells around a grid points
        p_mass = weight.sum(1)
        V_grid = (velocity_embed[neighs] * weight[:, :, None]).sum(1) / np.maximum(
            1, p_mass
        )[:, None]
        if autoscale:
            V_grid /= 3 * quiver_autoscale(X_grid, V_grid)
        results.append(V_grid)
        p_mass_list.append(p_mass)

    if build_grid:
        min_mass *= np.percentile(np.hstack(p_mass_list), 99) / 100

    # Keep grid points that are sufficiently populated in every dataset.
    mass_index = reduce(
        np.intersect1d, [np.where(p_mass > min_mass)[0] for p_mass in p_mass_list]
    )
    results = np.hstack(results)
    results = results[mass_index]
    print(results.shape)
    # Columns 2:4 hold the grid vectors of the first dataset.
    length_filter = np.sqrt((results[:, 2:4] ** 2).sum(1)) > length_cutoff
    return results[length_filter]