Source code for pyrovelocity.api

"""Model training API for Pyro-Velocity."""
from typing import Dict
from typing import Optional
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from anndata._core.anndata import AnnData
from numpy import ndarray
from pyro import poutine
from pyro.infer.autoguide import AutoNormal
from pyro.infer.autoguide.guides import AutoGuideList
from sklearn.model_selection import train_test_split

from pyrovelocity._velocity import PyroVelocity


def _build_validation_guide(model):
    """Build the guide used to fit held-out cells for the "auto" guide types.

    Per-cell sites ("cell_time", "u_read_depth", "s_read_depth") get a fresh
    AutoNormal guide, while the last component of the trained guide is reused
    with its "param" sites hidden via ``poutine.block``.
    """
    pyro_model = model.module._model
    guide = AutoGuideList(pyro_model, create_plates=pyro_model.create_plates)
    guide.append(
        AutoNormal(
            poutine.block(
                pyro_model,
                expose=["cell_time", "u_read_depth", "s_read_depth"],
            ),
            init_scale=0.1,
        )
    )
    guide.append(poutine.block(model.module._guide[-1], hide_types=["param"]))
    return guide


def train_model(
    adata: AnnData,
    guide_type: str = "auto",
    model_type: str = "auto",
    svi_train: bool = False,  # SVI training is turned off by default
    batch_size: int = -1,
    train_size: float = 1.0,
    use_gpu: int = 0,
    likelihood: str = "Poisson",
    num_samples: int = 30,
    log_every: int = 100,
    cell_state: str = "clusters",
    patient_improve: float = 5e-4,
    patient_init: int = 30,
    seed: int = 99,
    lr: float = 0.01,
    max_epochs: int = 3000,
    include_prior: bool = True,
    library_size: bool = True,
    offset: bool = False,
    input_type: str = "raw",
    cell_specific_kinetics: Optional[str] = None,
    kinetics_num: int = 2,
    loss_plot_path: str = "loss_plot.png",
) -> Tuple[PyroVelocity, Dict[str, ndarray]]:
    """
    Train a PyroVelocity model to provide probabilistic estimates of RNA velocity
    for single-cell RNA sequencing data with quantified splice variants.

    Three training paths exist:

    1. ``svi_train=True`` with a ``velocity_auto*`` guide: ``model.train`` is
       used with early stopping and an optional validation split.
    2. ``train_size >= 1``: the whole dataset is fit with
       ``model.train_faster`` (full batch) or
       ``model.train_faster_with_batch`` (mini-batches).
    3. ``train_size < 1``: cells are split (in order, without shuffling) into
       a training and a validation set, each of which is fit separately.

    In every path a loss plot is written to ``loss_plot_path``.

    Args:
        adata (AnnData): An AnnData object containing the input data.
        guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto".
        model_type (str, optional): The type of Pyro model. Default is "auto".
        svi_train (bool, optional): Whether to use Stochastic Variational Inference for training. Default is False.
        batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset.
        train_size (float, optional): Proportion of data to be used for training. Default is 1.0.
        use_gpu (int, optional): Whether to use GPU for training. Default is 0, which indicates not using GPU.
        likelihood (str, optional): Likelihood function for the Pyro model. Default is "Poisson".
        num_samples (int, optional): Number of posterior samples. Default is 30.
        log_every (int, optional): Frequency of logging progress. Default is 100.
        cell_state (str, optional): Cell state attribute in the AnnData object. Default is "clusters".
        patient_improve (float, optional): Minimum improvement in training loss for early stopping. Default is 5e-4.
        patient_init (int, optional): Number of initial training epochs before early stopping is enabled. Default is 30.
        seed (int, optional): Random seed for reproducibility. Default is 99.
        lr (float, optional): Learning rate for the optimizer. Default is 0.01.
        max_epochs (int, optional): Maximum number of training epochs. Default is 3000.
        include_prior (bool, optional): Whether to include prior information in the model. Default is True.
        library_size (bool, optional): Whether to correct for library size. Default is True.
        offset (bool, optional): Whether to add an offset to the model. Default is False.
        input_type (str, optional): Type of input data. Default is "raw".
        cell_specific_kinetics (Optional[str], optional): Name of the attribute containing cell-specific kinetics information. Default is None.
        kinetics_num (int, optional): Number of kinetics parameters. Default is 2.
        loss_plot_path (str, optional): Path to save the loss plot. Default is "loss_plot.png".

    Returns:
        Tuple[PyroVelocity, Dict[str, ndarray]]: A tuple containing the trained
        PyroVelocity model and a dictionary of posterior samples.

        NOTE: when ``train_size < 1`` (and the SVI path is not taken) the
        function instead returns the 4-tuple
        ``(posterior_samples, pos_test, train_ind, test_ind)``: posterior
        samples for the training cells, posterior samples for the validation
        cells, and the two index arrays. This differs from the annotation and
        is kept for backward compatibility.

    Raises:
        NotImplementedError: If ``train_size < 1`` and ``guide_type`` is not one
            of "auto", "auto_t0_constraint", "velocity_auto" or
            "velocity_auto_t0_constraint" (e.g. "velocity_auto_depth").

    Examples:
        >>> from pyrovelocity.api import train_model
        >>> from pyrovelocity.utils import generate_sample_data
        >>> from pyrovelocity.data import copy_raw_counts
        >>> adata = generate_sample_data(random_seed=99)
        >>> copy_raw_counts(adata)
        >>> model, posterior_samples = train_model(adata, seed=99, max_epochs=200, loss_plot_path="loss_plot_docs.png")
    """
    auto_guides = {"auto", "auto_t0_constraint"}
    velocity_guides = {"velocity_auto", "velocity_auto_t0_constraint"}

    PyroVelocity.setup_anndata(adata)
    model = PyroVelocity(
        adata,
        likelihood=likelihood,
        model_type=model_type,
        guide_type=guide_type,
        correct_library_size=library_size,
        add_offset=offset,
        include_prior=include_prior,
        input_type=input_type,
        cell_specific_kinetics=cell_specific_kinetics,
        kinetics_num=kinetics_num,
    )

    # ------------------------------------------------------------------
    # Path 1: SVI training through model.train with early stopping.
    # ------------------------------------------------------------------
    if svi_train and guide_type in velocity_guides:
        if batch_size == -1:
            batch_size = adata.shape[0]
        model.train(
            max_epochs=max_epochs,
            lr=lr,
            use_gpu=use_gpu,
            batch_size=batch_size,
            train_size=train_size,
            valid_size=1 - train_size,
            check_val_every_n_epoch=1,
            early_stopping=True,
            patience=patient_init,
            min_delta=patient_improve,
        )

        fig, ax = plt.subplots()
        fig.set_size_inches(2.5, 1.5)
        # The last history entry is dropped from the plot, as in the
        # original implementation.
        ax.scatter(
            model.history_["elbo_train"].index[:-1],
            -model.history_["elbo_train"][:-1],
            label="Train",
        )
        if train_size < 1:
            ax.scatter(
                model.history_["elbo_validation"].index[:-1],
                -model.history_["elbo_validation"][:-1],
                label="Valid",
            )
        set_loss_plot_axes(ax)
        fig.savefig(loss_plot_path, facecolor="white", bbox_inches="tight")

        posterior_samples = model.generate_posterior_samples(
            model.adata, num_samples=num_samples, batch_size=512
        )
        return model, posterior_samples

    # ------------------------------------------------------------------
    # Path 2: fit on the whole dataset (supports velocity_auto_depth).
    # ------------------------------------------------------------------
    if train_size >= 1:
        if batch_size == -1:
            batch_size = adata.shape[0]

        if batch_size >= adata.shape[0]:
            losses = model.train_faster(
                max_epochs=max_epochs,
                lr=lr,
                use_gpu=use_gpu,
                seed=seed,
                patient_improve=patient_improve,
                patient_init=patient_init,
                log_every=log_every,
            )
        else:
            losses = model.train_faster_with_batch(
                max_epochs=max_epochs,
                batch_size=batch_size,
                log_every=log_every,
                lr=lr,
                use_gpu=use_gpu,
                seed=seed,
                patient_improve=patient_improve,
                patient_init=patient_init,
            )

        fig, ax = plt.subplots()
        fig.set_size_inches(2.5, 1.5)
        ax.scatter(
            np.arange(len(losses)), -np.array(losses), label="train", alpha=0.25
        )
        set_loss_plot_axes(ax)

        posterior_samples = model.generate_posterior_samples(
            model.adata, num_samples=num_samples, batch_size=512
        )
        fig.savefig(loss_plot_path, facecolor="white", bbox_inches="tight")
        return model, posterior_samples

    # ------------------------------------------------------------------
    # Path 3: train / validation procedure.
    # ------------------------------------------------------------------
    # Fail fast, before any training, with a descriptive error. The original
    # code used a bare ``raise`` here, which only ever produced
    # "RuntimeError: No active exception to reraise". NotImplementedError is a
    # RuntimeError subclass, so existing handlers keep working.
    if guide_type not in auto_guides | velocity_guides:
        raise NotImplementedError(
            f"guide_type={guide_type!r} is not supported when train_size < 1; "
            f"supported guide types are {sorted(auto_guides | velocity_guides)}"
        )

    indices = np.arange(adata.shape[0])
    # shuffle=False gives a deterministic, order-preserving split. The cluster
    # labels are split alongside the indices but are not used afterwards.
    train_ind, test_ind, _cluster_train, _cluster_test = train_test_split(
        indices,
        adata.obs.loc[:, cell_state].values,
        test_size=1 - train_size,
        random_state=seed,
        shuffle=False,
    )

    train_batch_size = train_ind.shape[0] if batch_size == -1 else batch_size
    losses = model.train_faster_with_batch(
        max_epochs=max_epochs,
        batch_size=train_batch_size,
        indices=train_ind,
        log_every=log_every,
        lr=lr,
        use_gpu=use_gpu,
        seed=seed,
        patient_improve=patient_improve,
        patient_init=patient_init,
    )
    posterior_samples = model.generate_posterior_samples(
        model.adata, num_samples=num_samples, indices=train_ind, batch_size=512
    )

    test_batch_size = test_ind.shape[0] if batch_size == -1 else batch_size
    if guide_type in auto_guides:
        losses_test = model.train_faster_with_batch(
            max_epochs=max_epochs,
            batch_size=test_batch_size,
            indices=test_ind,
            new_valid_guide=_build_validation_guide(model),
            log_every=log_every,
            lr=lr,
            use_gpu=use_gpu,
            seed=seed,
            elbo_name="-ELBO validation",
        )
    else:
        # velocity_auto / velocity_auto_t0_constraint: the existing guide is
        # reused on the validation cells.
        print("valid new guide")
        losses_test = model.train_faster_with_batch(
            max_epochs=max_epochs,
            batch_size=test_batch_size,
            indices=test_ind,
            log_every=log_every,
            lr=lr,
            use_gpu=use_gpu,
            seed=seed,
        )

    # Honour the caller's num_samples (previously hard-coded to 30, which is
    # also the default, so default behaviour is unchanged).
    pos_test = model.generate_posterior_samples(
        model.adata, num_samples=num_samples, indices=test_ind, batch_size=512
    )

    fig, ax = plt.subplots()
    fig.set_size_inches(2.5, 1.5)
    ax.scatter(
        np.arange(len(losses)), -np.array(losses), label="train", alpha=0.25
    )
    ax.scatter(
        np.arange(len(losses_test)),
        -np.array(losses_test),
        label="validation",
        alpha=0.25,
    )
    set_loss_plot_axes(ax)
    ax.legend()
    fig.savefig(loss_plot_path, facecolor="white", bbox_inches="tight")
    return posterior_samples, pos_test, train_ind, test_ind
def set_loss_plot_axes(ax):
    """Apply the shared axis styling used by the training loss plots.

    Sets the y-axis to a "symlog" scale and labels the axes "Epochs" (x) and
    "-ELBO" (y).

    Args:
        ax: The axes object to configure; it must provide ``set_yscale``,
            ``set_xlabel`` and ``set_ylabel`` (e.g. a matplotlib ``Axes``).
    """
    ax.set_yscale("symlog")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("-ELBO")