pyrovelocity.api

Model training API for Pyro-Velocity.

pyrovelocity.api.set_loss_plot_axes(ax)

pyrovelocity.api.train_model(adata, guide_type='auto', model_type='auto', svi_train=False, batch_size=-1, train_size=1.0, use_gpu=0, likelihood='Poisson', num_samples=30, log_every=100, cell_state='clusters', patient_improve=0.0005, patient_init=30, seed=99, lr=0.01, max_epochs=3000, include_prior=True, library_size=True, offset=False, input_type='raw', cell_specific_kinetics=None, kinetics_num=2, loss_plot_path='loss_plot.png')

Train a PyroVelocity model to provide probabilistic estimates of RNA velocity for single-cell RNA sequencing data with quantified splice variants.

Parameters:
  • adata (AnnData) – An AnnData object containing the input data.

  • guide_type (str, optional) – The type of guide function for the Pyro model. Default is “auto”.

  • model_type (str, optional) – The type of Pyro model. Default is “auto”.

  • svi_train (bool, optional) – Whether to use Stochastic Variational Inference for training. Default is False.

  • batch_size (int, optional) – Batch size for training. Default is -1, which uses the full dataset.

  • train_size (float, optional) – Proportion of data to be used for training. Default is 1.0.

  • use_gpu (int, optional) – Whether to use a GPU for training. Default is 0, which disables GPU use.

  • likelihood (str, optional) – Likelihood function for the Pyro model. Default is “Poisson”.

  • num_samples (int, optional) – Number of posterior samples. Default is 30.

  • log_every (int, optional) – Frequency of logging progress. Default is 100.

  • cell_state (str, optional) – Cell state attribute in the AnnData object. Default is “clusters”.

  • patient_improve (float, optional) – Minimum improvement in training loss for early stopping. Default is 5e-4.

  • patient_init (int, optional) – Number of initial training epochs before early stopping is enabled. Default is 30.

  • seed (int, optional) – Random seed for reproducibility. Default is 99.

  • lr (float, optional) – Learning rate for the optimizer. Default is 0.01.

  • max_epochs (int, optional) – Maximum number of training epochs. Default is 3000.

  • include_prior (bool, optional) – Whether to include prior information in the model. Default is True.

  • library_size (bool, optional) – Whether to correct for library size. Default is True.

  • offset (bool, optional) – Whether to add an offset to the model. Default is False.

  • input_type (str, optional) – Type of input data. Default is “raw”.

  • cell_specific_kinetics (Optional[str], optional) – Name of the attribute containing cell-specific kinetics information. Default is None.

  • kinetics_num (int, optional) – Number of kinetics parameters. Default is 2.

  • loss_plot_path (str, optional) – Path to save the loss plot. Default is “loss_plot.png”.
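
With this many keyword arguments, it can help to collect the ones you change in a single place and check them before starting a long training run. The sketch below is one possible way to do that. `TrainConfig` is a hypothetical helper and not part of pyrovelocity; it mirrors a subset of the defaults documented above. The validation rules are assumptions inferred from the parameter descriptions, not checks that `train_model` is known to perform.

```python
from dataclasses import asdict, dataclass


@dataclass
class TrainConfig:
    """Hypothetical container for a subset of train_model keyword arguments."""

    batch_size: int = -1
    train_size: float = 1.0
    use_gpu: int = 0
    seed: int = 99
    lr: float = 0.01
    max_epochs: int = 3000
    loss_plot_path: str = "loss_plot.png"

    def __post_init__(self):
        # Assumed constraint: -1 means "full dataset", otherwise a positive size.
        if self.batch_size != -1 and self.batch_size <= 0:
            raise ValueError("batch_size must be -1 or a positive integer")
        # Assumed constraint: train_size is a proportion in (0, 1].
        if not 0.0 < self.train_size <= 1.0:
            raise ValueError("train_size must be in (0, 1]")
        # Assumed constraint: learning rate and epoch budget are positive.
        if self.lr <= 0 or self.max_epochs <= 0:
            raise ValueError("lr and max_epochs must be positive")


# Override only what differs from the documented defaults.
config = TrainConfig(max_epochs=200, loss_plot_path="loss_plot_docs.png")
kwargs = asdict(config)

# With pyrovelocity installed and an AnnData object `adata` prepared:
# model, posterior_samples = train_model(adata, **kwargs)
```

Because `asdict` returns a plain dictionary, the same `kwargs` can also be logged or saved alongside the loss plot to record how a run was configured.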

Returns:

A tuple containing the trained PyroVelocity model and a dictionary of posterior samples.

Return type:

Tuple[PyroVelocity, Dict[str, ndarray]]

Examples

>>> from pyrovelocity.api import train_model
>>> from pyrovelocity.utils import generate_sample_data
>>> from pyrovelocity.data import copy_raw_counts
>>> adata = generate_sample_data(random_seed=99)
>>> copy_raw_counts(adata)
>>> model, posterior_samples = train_model(adata, seed=99, max_epochs=200, loss_plot_path="loss_plot_docs.png")
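
The second return value is documented as `Dict[str, ndarray]`, so it can be inspected with ordinary NumPy operations. The following is a minimal sketch of summarizing such a dictionary. It runs on synthetic arrays so that it does not depend on a trained model. The key names (`param_a`, `param_b`), the array shapes, and the assumption that the first axis indexes posterior draws are all illustrative; check the shapes of your own `posterior_samples` before relying on them.

```python
import numpy as np


def summarize_posterior(posterior_samples):
    """Return per-variable shape, mean, and standard deviation over axis 0."""
    summary = {}
    for name, samples in posterior_samples.items():
        samples = np.asarray(samples)
        # Assumption: axis 0 indexes the posterior draws.
        summary[name] = {
            "shape": samples.shape,
            "mean": samples.mean(axis=0),
            "std": samples.std(axis=0),
        }
    return summary


# Synthetic stand-in for the returned dictionary; keys and shapes are made up.
rng = np.random.default_rng(99)
fake_samples = {
    "param_a": rng.normal(size=(30, 1, 5)),
    "param_b": rng.gamma(2.0, size=(30, 10, 5)),
}

summary = summarize_posterior(fake_samples)
for name, stats in summary.items():
    print(name, stats["shape"], stats["mean"].shape)
```

Replacing `fake_samples` with the `posterior_samples` returned by `train_model` should work as long as every value is array-like with a leading sample axis.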