Source code for pyrovelocity._trainer

import math
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union

import mlflow
import numpy as np
import pyro
import scipy
import torch
from pyro.infer import Trace_ELBO
from pyro.infer import TraceEnum_ELBO
from pyro.infer.autoguide.guides import AutoGuideList
from pyro.optim.clipped_adam import ClippedAdam
from pyro.optim.optim import PyroOptim
from scvi.dataloaders import DataSplitter
from scvi.train import PyroTrainingPlan
from scvi.train import TrainRunner

from pyrovelocity.utils import _get_fn_args_from_batch

from ._velocity_module import VelocityModule


class VelocityAdam(ClippedAdam):
    """ClippedAdam subclass whose update zeroes NaN gradient entries.

    Each call to :meth:`step` decays the learning rate by the group's
    ``lrd`` factor, clamps every gradient element to
    ``[-clip_norm, clip_norm]``, replaces NaN gradient entries with zero and
    then applies a bias-corrected Adam update.
    """

    def step(self, closure: Optional[Callable] = None) -> Optional[Any]:
        """Perform one optimisation step over all parameter groups.

        Args:
            closure: Optional callable that is invoked once, before any
                parameter is updated; its return value is passed through.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was supplied.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            # Multiplicative learning-rate decay, applied once per call.
            group["lr"] *= group["lrd"]

            clip = group["clip_norm"]
            beta1, beta2 = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]

            for param in group["params"]:
                if param.grad is None:
                    continue

                # Clamp and sanitise in place on the parameter's gradient.
                grad = param.grad.data
                grad.clamp_(-clip, clip)
                # freeze non-velocity gene gradient: NaN entries become zero
                # and therefore contribute nothing to the moment estimates.
                grad.masked_fill_(torch.isnan(grad), 0.0)

                state = self.state[param]
                if not state:
                    # First time this parameter is seen: start the step
                    # counter and both running averages at zero.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(grad)
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                first_moment = state["exp_avg"]
                second_moment = state["exp_avg_sq"]

                state["step"] += 1
                t = state["step"]

                if weight_decay != 0:
                    # Out-of-place add: the stored gradient is left as is.
                    grad = grad.add(param.data, alpha=weight_decay)

                # Update the running averages of the gradient and its square.
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = second_moment.sqrt().add_(eps)

                bias_correction1 = 1 - beta1**t
                bias_correction2 = 1 - beta2**t
                step_size = (
                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                )

                param.data.addcdiv_(first_moment, denom, value=-step_size)

        return loss
def VelocityClippedAdam(optim_args: Dict[str, float]) -> PyroOptim:
    """
    Wraps :class:`VelocityAdam` with :class:`~pyro.optim.optim.PyroOptim`.

    Args:
        optim_args: Optimizer arguments, handed to ``PyroOptim`` unchanged
            together with the ``VelocityAdam`` class.

    Returns:
        The resulting ``PyroOptim`` instance.
    """
    wrapped_optimizer = PyroOptim(VelocityAdam, optim_args)
    return wrapped_optimizer
class EnumTrainingPlan(PyroTrainingPlan):
    """Pyro training plan that optimises a ``TraceEnum_ELBO`` loss.

    The plan builds its own :class:`pyro.infer.SVI` object from the module's
    model and guide and reports the per-batch mean loss at the end of every
    training and validation epoch.
    """

    def __init__(
        self,
        pyro_velocity: VelocityModule,
        optim: Optional[pyro.optim.PyroOptim] = None,
    ):
        """Set up the loss, the SVI object and the element count.

        Args:
            pyro_velocity: Module exposing ``model``, ``guide``,
                ``num_genes`` and ``num_cells``.
            optim: Optional Pyro optimizer forwarded to the parent plan.
        """
        enum_elbo = TraceEnum_ELBO(strict_enumeration_warning=True)
        super().__init__(pyro_velocity, enum_elbo, optim)

        self.svi = pyro.infer.SVI(
            model=self.module.model,
            guide=self.module.guide,
            optim=self.optim,
            loss=self.loss_fn,
        )
        # Kept as an attribute; the epoch-end hooks below do not use it.
        self.n_elem = self.module.num_genes * self.module.num_cells * 2

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Run one SVI update on ``batch`` and report its loss.

        Returns:
            Dict with the step loss (``"train_step_loss"``) and
            ``"num_elem"``, twice the product of the first two dimensions of
            the first positional argument extracted from the batch.
        """
        args, kwargs = _get_fn_args_from_batch(batch)
        step_loss = self.svi.step(*args, **kwargs)
        first_arg = args[0]
        return {
            "train_step_loss": step_loss,
            "num_elem": first_arg.shape[0] * first_arg.shape[1] * 2,
        }

    def training_epoch_end(self, outputs):
        """Log the mean training loss over the epoch's batches."""
        total_loss, batch_count = 0, 0
        for batch_count, step_output in enumerate(outputs, start=1):
            total_loss += step_output["train_step_loss"]
        # Averaged per batch (not per element).
        self.log(
            "elbo_train", total_loss / batch_count, prog_bar=True, on_epoch=True
        )

    def validation_step(self, batch, batch_idx):
        """Evaluate the loss on ``batch`` without updating parameters.

        Returns:
            Dict with the step loss (``"valid_step_loss"``) and
            ``"num_elem"``, computed as in :meth:`training_step`.
        """
        args, kwargs = _get_fn_args_from_batch(batch)
        step_loss = self.svi.evaluate_loss(*args, **kwargs)
        first_arg = args[0]
        return {
            "valid_step_loss": step_loss,
            "num_elem": first_arg.shape[0] * first_arg.shape[1] * 2,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate validation step information."""
        total_loss, batch_count = 0, 0
        for batch_count, step_output in enumerate(outputs, start=1):
            total_loss += step_output["valid_step_loss"]
        # Averaged per batch (not per element).
        self.log(
            "elbo_validation",
            total_loss / batch_count,
            prog_bar=True,
            on_epoch=True,
        )
class VelocityTrainingMixin:
    """Training entry points for velocity models.

    The methods below read ``self.module``, ``self.adata``,
    ``self.adata_manager``, ``self._validate_anndata`` and
    ``self._make_data_loader``; the class mixing this in must provide them.
    """

    @staticmethod
    def _device_from_use_gpu(use_gpu: Optional[Union[str, int, bool]]) -> str:
        """Translate a ``use_gpu`` argument into a torch device string.

        * ``False``, ``"cpu"`` and ``-1`` select the CPU.
        * ``True`` and ``None`` select the default CUDA device when CUDA is
          available and fall back to the CPU otherwise.
        * A string that already starts with ``"cuda"`` is returned unchanged.
        * Anything else (typically a GPU index) becomes ``"cuda:<value>"``.
        """
        if (use_gpu is False) or (use_gpu == "cpu") or (use_gpu == -1):
            return "cpu"
        if (use_gpu is True) or (use_gpu is None):
            # Formatting these into "cuda:True" / "cuda:None" would yield an
            # invalid device string.
            return "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(use_gpu, str) and use_gpu.startswith("cuda"):
            # Already a device name; do not prepend a second "cuda:".
            return use_gpu
        return f"cuda:{use_gpu}"

    def train(
        self,
        use_gpu: Optional[Union[str, int, bool]] = 0,
        early_stopping: bool = False,
        seed: int = 99,
        lr: float = 1e-3,
        train_size: float = 1.0,
        valid_size: float = 0.0,
        batch_size: int = 256,
        max_epochs: int = 100,
        check_val_every_n_epoch: Optional[int] = 1,
        patience: int = 10,
        min_delta: float = 0.0,
        **kwargs,
    ):
        """Train through ``TrainRunner`` with an ``EnumTrainingPlan``.

        Clears the Pyro parameter store and seeds Pyro's RNG before building
        the data splitter, training plan and runner.

        Args:
            use_gpu: Forwarded to ``DataSplitter`` and ``TrainRunner``.
            early_stopping: Forwarded to ``TrainRunner``.
            seed: Seed passed to ``pyro.set_rng_seed``.
            lr: Initial learning rate of the optimizer; it decays by a
                factor of ``0.9999`` per optimizer step.
            train_size: Forwarded to ``DataSplitter``.
            valid_size: Forwarded to ``DataSplitter`` as ``validation_size``.
            batch_size: Forwarded to ``DataSplitter``.
            max_epochs: Forwarded to ``TrainRunner``.
            check_val_every_n_epoch: Forwarded to ``TrainRunner``.
            patience: Forwarded as ``early_stopping_patience``.
            min_delta: Forwarded as ``early_stopping_min_delta``.
            **kwargs: Extra keyword arguments for ``TrainRunner``.

        Returns:
            Whatever calling the runner returns.
        """
        print("base train function")
        pyro.clear_param_store()
        pyro.set_rng_seed(seed)
        data_splitter = DataSplitter(
            self.adata_manager,
            train_size=train_size,
            validation_size=valid_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = EnumTrainingPlan(
            self.module, VelocityClippedAdam({"lr": lr, "lrd": 0.9999})
        )
        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            check_val_every_n_epoch=check_val_every_n_epoch,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            early_stopping=early_stopping,
            early_stopping_patience=patience,
            early_stopping_min_delta=min_delta,
            **kwargs,
        )
        return runner()

    def train_faster(
        self,
        use_gpu: Optional[Union[str, int, bool]] = 0,
        seed: int = 99,
        lr: float = 0.01,
        max_epochs: int = 5000,
        log_every: int = 100,
        patient_init: int = 45,
        patient_improve: float = 0.001,
    ) -> List[float]:
        """Fit on the whole dataset at once, held on a single device.

        All of ``adata`` is moved to the device up front for faster IO. As
        stated by the original authors, this speeds up larger datasets by
        5-6 fold but should not be used for more than 20k cells with less
        than 40GB of GPU memory, and it ignores validation cells.

        Args:
            use_gpu: Device selector; ``False``, ``"cpu"`` or ``-1`` mean
                CPU, an index selects ``cuda:<index>``.
            seed: Seed passed to ``pyro.set_rng_seed``.
            lr: Initial learning rate; it decays to a tenth of this value
                over ``max_epochs`` steps.
            max_epochs: Maximum number of SVI steps.
            log_every: Interval (in steps) between progress logs; the
                patience check only starts once ``step > log_every``.
            patient_init: Number of consecutive non-improving steps tolerated
                before stopping.
            patient_improve: Relative improvement threshold; a step counts as
                non-improving when ``previous - current`` is below
                ``previous * patient_improve``.

        Returns:
            The normalised loss of every completed step. The step that
            triggers early stopping is not appended.
        """
        print("train_faster")
        device = self._device_from_use_gpu(use_gpu)

        pyro.clear_param_store()
        pyro.set_rng_seed(seed)
        pyro.enable_validation(True)

        optim = VelocityClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / max_epochs)})
        self.module._model = self.module._model.to(device)
        print("TraceEnum")
        svi = pyro.infer.SVI(
            self.module._model,
            self.module._guide,
            optim,
            Trace_ELBO(strict_enumeration_warning=True),
        )

        # Losses are reported per matrix entry of both count layers.
        normalizer = self.adata.shape[0] * self.adata.shape[1] * 2

        def layer_tensor(layer_name):
            # Densify sparse layers before handing them to torch.
            layer = self.adata.layers[layer_name]
            if scipy.sparse.issparse(layer):
                layer = np.array(layer.toarray(), dtype="float32")
            return torch.tensor(layer, dtype=torch.float32).to(device)

        u = layer_tensor("raw_unspliced")
        s = layer_tensor("raw_spliced")

        # epsilon keeps the logarithm finite for zero library sizes.
        epsilon = 1e-6
        log_u_lib = np.log(self.adata.obs.u_lib_size_raw + epsilon)
        log_s_lib = np.log(self.adata.obs.s_lib_size_raw + epsilon)

        def per_cell(values):
            return torch.tensor(
                np.array(values, dtype="float32"), dtype=torch.float32
            ).to(device)

        u_library = per_cell(log_u_lib)
        s_library = per_cell(log_s_lib)

        def broadcast(statistic):
            # Repeat a scalar summary so that every cell carries it.
            return (
                torch.tensor(statistic, dtype=torch.float32)
                .expand(u_library.shape)
                .to(device)
            )

        u_library_mean = broadcast(np.mean(log_u_lib))
        s_library_mean = broadcast(np.mean(log_s_lib))
        u_library_scale = broadcast(np.std(log_u_lib))
        s_library_scale = broadcast(np.std(log_s_lib))

        print(u_library_scale.shape)
        print(u.shape)
        print(u_library.shape)

        if "pyro_cell_state" in self.adata.obs.columns:
            cell_state = (
                torch.tensor(
                    np.array(self.adata.obs.pyro_cell_state, dtype="float32"),
                    dtype=torch.float32,
                )
                .to(device)
                .reshape(-1, 1)
            )
        else:
            cell_state = None

        # The inputs never change between steps, so assemble them once.
        step_args = (
            u,
            s,
            u_library.reshape(-1, 1),
            s_library.reshape(-1, 1),
            u_library_mean.reshape(-1, 1),
            s_library_mean.reshape(-1, 1),
            u_library_scale.reshape(-1, 1),
            s_library_scale.reshape(-1, 1),
            None,
            cell_state,
        )

        losses = []
        patience = patient_init
        for step in range(max_epochs):
            elbos = svi.step(*step_args) / normalizer
            if (step == 0) or (
                ((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
            ):
                mlflow.log_metric("-ELBO", -elbos, step=step + 1)
                print(
                    f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
                )
            if step > log_every:
                if (losses[-1] - elbos) < losses[-1] * patient_improve:
                    patience -= 1
                else:
                    patience = patient_init
            if patience <= 0:
                break
            losses.append(elbos)
        mlflow.log_metric("-ELBO", -elbos, step=step + 1)
        mlflow.log_metric("real_epochs", step + 1)
        print(f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}")
        return losses

    def train_faster_with_batch(
        self,
        use_gpu: Optional[Union[str, int, bool]] = 0,
        seed: int = 99,
        lr: float = 1e-2,
        max_epochs: int = 5000,
        log_every: int = 100,
        indices: Optional[Sequence[int]] = None,
        batch_size: Optional[int] = None,
        # Quoted so the annotation is resolved lazily.
        new_valid_guide: Optional["AutoGuideList"] = None,
        patient_init: int = 45,
        patient_improve: float = 0.0,
        elbo_name: str = "-ELBO",
    ) -> List[float]:
        """Fit with mini-batches drawn from the model's own data loader.

        Args:
            use_gpu: Device selector; ``False``, ``"cpu"`` or ``-1`` mean
                CPU, an index selects ``cuda:<index>``.
            seed: Seed passed to ``pyro.set_rng_seed``.
            lr: Initial learning rate; it is multiplied by
                ``0.1 ** (1 / max_epochs)`` at every optimizer step.
            max_epochs: Maximum number of passes over the data loader.
            log_every: Interval (in epochs) between progress logs; the
                patience check only starts once ``step > log_every``.
            indices: Forwarded to ``self._make_data_loader``.
            batch_size: Forwarded to ``self._make_data_loader``.
            new_valid_guide: Guide used instead of ``self.module._guide``
                when given.
            patient_init: Number of consecutive non-improving epochs
                tolerated before stopping.
            patient_improve: Relative improvement threshold; an epoch counts
                as non-improving when ``previous - current`` is below
                ``previous * patient_improve``.
            elbo_name: Name of the MLflow metric under which the negated
                epoch loss is logged.

        Returns:
            The mean per-batch loss of every completed epoch. The epoch that
            triggers early stopping is not appended.
        """
        print("train_faster_with_batch")
        device = self._device_from_use_gpu(use_gpu)

        pyro.clear_param_store()
        pyro.set_rng_seed(seed)
        pyro.enable_validation(True)

        adata = self._validate_anndata(self.adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        optim = VelocityClippedAdam({"lr": lr, "lrd": 0.1 ** (1 / max_epochs)})
        self.module._model = self.module._model.to(device)

        guide = self.module._guide if new_valid_guide is None else new_valid_guide
        svi = pyro.infer.SVI(
            self.module._model,
            guide,
            optim,
            Trace_ELBO(strict_enumeration_warning=True),
        )

        losses = []
        patience = patient_init
        for step in range(max_epochs):
            n_batch = 0
            elbos = 0
            for tensor_dict in scdl:
                args, kwargs = _get_fn_args_from_batch(tensor_dict)
                args = [a.to(device) if a is not None else a for a in args]
                elbos += svi.step(*args, **kwargs)
                n_batch += 1
            # Epoch loss is the mean over batches.
            elbos = elbos / n_batch
            if (step == 0) or (
                ((step + 1) % log_every == 0) and ((step + 1) < max_epochs)
            ):
                mlflow.log_metric(elbo_name, -elbos, step=step + 1)
                print(
                    f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}"
                )
            if step > log_every:
                if (losses[-1] - elbos) < losses[-1] * patient_improve:
                    patience -= 1
                else:
                    patience = patient_init
            if patience <= 0:
                break
            losses.append(elbos)
        mlflow.log_metric(elbo_name, -elbos, step=step + 1)
        mlflow.log_metric("real_epochs", step + 1)
        # Report the 1-based epoch, matching the in-loop message and metrics.
        print(f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}")
        return losses