pyrovelocity.api
Model training API for Pyro-Velocity.
- pyrovelocity.api.train_model(adata, guide_type='auto', model_type='auto', svi_train=False, batch_size=-1, train_size=1.0, use_gpu=0, likelihood='Poisson', num_samples=30, log_every=100, cell_state='clusters', patient_improve=0.0005, patient_init=30, seed=99, lr=0.01, max_epochs=3000, include_prior=True, library_size=True, offset=False, input_type='raw', cell_specific_kinetics=None, kinetics_num=2, loss_plot_path='loss_plot.png')
Train a PyroVelocity model to provide probabilistic estimates of RNA velocity for single-cell RNA sequencing data with quantified splice variants.
- Parameters:
adata (AnnData) – An AnnData object containing the input data.
guide_type (str, optional) – The type of guide function for the Pyro model. Default is “auto”.
model_type (str, optional) – The type of Pyro model. Default is “auto”.
svi_train (bool, optional) – Whether to use Stochastic Variational Inference for training. Default is False.
batch_size (int, optional) – Batch size for training. Default is -1, which indicates using the full dataset.
train_size (float, optional) – Proportion of data to be used for training. Default is 1.0.
use_gpu (int, optional) – Whether to use GPU for training. Default is 0, which indicates not using GPU.
likelihood (str, optional) – Likelihood function for the Pyro model. Default is “Poisson”.
num_samples (int, optional) – Number of posterior samples. Default is 30.
log_every (int, optional) – Frequency of logging progress. Default is 100.
cell_state (str, optional) – Cell state attribute in the AnnData object. Default is “clusters”.
patient_improve (float, optional) – Minimum improvement in training loss for early stopping. Default is 5e-4.
patient_init (int, optional) – Number of initial training epochs before early stopping is enabled. Default is 30.
seed (int, optional) – Random seed for reproducibility. Default is 99.
lr (float, optional) – Learning rate for the optimizer. Default is 0.01.
max_epochs (int, optional) – Maximum number of training epochs. Default is 3000.
include_prior (bool, optional) – Whether to include prior information in the model. Default is True.
library_size (bool, optional) – Whether to correct for library size. Default is True.
offset (bool, optional) – Whether to add an offset to the model. Default is False.
input_type (str, optional) – Type of input data. Default is “raw”.
cell_specific_kinetics (Optional[str], optional) – Name of the attribute containing cell-specific kinetics information. Default is None.
kinetics_num (int, optional) – Number of kinetics parameters. Default is 2.
loss_plot_path (str, optional) – Path to save the loss plot. Default is “loss_plot.png”.
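The interaction of max_epochs, patient_init and patient_improve can be pictured with a small, library-independent sketch. It is not the PyroVelocity implementation: the helper names (should_stop, run_training, toy_loss) are hypothetical, and the rule itself is an assumption, in particular that improvement is measured relative to the previous epoch's loss and that a single sub-threshold epoch is enough to stop.

```python
def should_stop(losses, patient_init=30, patient_improve=5e-4):
    """Hypothetical early-stopping check (assumed rule, not the library's).

    Stopping is only considered once more than `patient_init` epochs have
    been recorded. Training then stops when the relative improvement of the
    latest loss over the previous one falls below `patient_improve`.
    """
    if len(losses) <= patient_init or len(losses) < 2:
        return False
    previous, current = losses[-2], losses[-1]
    improvement = (previous - current) / abs(previous)
    return improvement < patient_improve


def run_training(loss_per_epoch, max_epochs=3000, patient_init=30,
                 patient_improve=5e-4):
    """Record one loss per epoch until max_epochs or early stopping."""
    losses = []
    for epoch in range(max_epochs):
        losses.append(loss_per_epoch(epoch))
        if should_stop(losses, patient_init, patient_improve):
            break
    return losses


def toy_loss(epoch):
    """Stand-in for a training loss: decays geometrically toward 1.0."""
    return 1.0 + 0.9 ** epoch


losses = run_training(toy_loss)
print(f"stopped after {len(losses)} of 3000 possible epochs")
```

Under these assumptions, the toy run ends well before max_epochs but never earlier than patient_init epochs, which is the behaviour the three parameter descriptions suggest.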
- Returns:
A tuple containing the trained PyroVelocity model and a dictionary of posterior samples.
- Return type:
Tuple[PyroVelocity, Dict[str, ndarray]]
Examples
>>> from pyrovelocity.api import train_model
>>> from pyrovelocity.utils import generate_sample_data
>>> from pyrovelocity.data import copy_raw_counts
>>> adata = generate_sample_data(random_seed=99)
>>> copy_raw_counts(adata)
>>> model, posterior_samples = train_model(adata, seed=99, max_epochs=200, loss_plot_path="loss_plot_docs.png")
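Since the second return value is documented as a Dict[str, ndarray] of posterior samples, one common follow-up is to reduce each array to a mean and a credible interval. The sketch below runs on synthetic data so that it does not depend on PyroVelocity. The key names ("alpha", "beta"), the array shapes, and the assumption that the leading axis indexes the num_samples draws are all illustrative and should be checked against the dictionary actually returned by train_model.

```python
import numpy as np


def summarize_posterior(posterior_samples, lower=5.0, upper=95.0):
    """Reduce each array of draws to a mean and a percentile interval.

    Assumes axis 0 of every array indexes the posterior draws.
    """
    summary = {}
    for name, draws in posterior_samples.items():
        draws = np.asarray(draws)
        summary[name] = {
            "mean": draws.mean(axis=0),
            "lower": np.percentile(draws, lower, axis=0),
            "upper": np.percentile(draws, upper, axis=0),
        }
    return summary


# Synthetic stand-in for a posterior samples dictionary:
# 30 draws (matching the default num_samples) for 5 hypothetical genes.
rng = np.random.default_rng(99)
fake_samples = {
    "alpha": rng.gamma(shape=2.0, scale=1.0, size=(30, 5)),
    "beta": rng.gamma(shape=1.0, scale=0.5, size=(30, 5)),
}

summary = summarize_posterior(fake_samples)
for name, stats in summary.items():
    print(name, stats["mean"].shape)
```

With only 30 draws per parameter, percentile intervals are coarse; raising num_samples would give smoother estimates, at the cost of more sampling time.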