Source code for pyroved.trainers.svi

from typing import Type, Optional

import torch
import pyro
import pyro.infer as infer
import pyro.optim as optim

from ..utils import set_deterministic_mode


class SVItrainer:
    """
    Stochastic variational inference (SVI) trainer for
    unsupervised and class-conditioned VED models consisting
    of one encoder and one decoder.

    Args:
        model:
            Initialized model. Must be a subclass of torch.nn.Module
            and have self.model and self.guide methods
        optimizer:
            Pyro optimizer (Defaults to Adam with learning rate 1e-3)
        loss:
            ELBO objective (Defaults to pyro.infer.Trace_ELBO, or to
            pyro.infer.TraceEnum_ELBO when enumerate_parallel is True)
        enumerate_parallel:
            Exact discrete enumeration for discrete latent variables
        seed:
            Enforces reproducibility
        kwargs:
            learning rate as 'lr' (Default: 1e-3). Only read when
            no optimizer is passed.

    Example:

    Train a model with SVI trainer using default settings

    >>> # Initialize model
    >>> data_dim = (28, 28)
    >>> rvae = pyroved.models.trVAE(data_dim, latent_dim=2, coord=1)
    >>> # Initialize SVI trainer
    >>> trainer = SVItrainer(rvae)
    >>> # Train for 200 epochs:
    >>> for _ in range(200):
    >>>     trainer.step(train_loader)
    >>>     trainer.print_statistics()

    Train a model with SVI trainer with a "time"-dependent KL scaling factor

    >>> # Initialize model
    >>> data_dim = (28, 28)
    >>> rvae = pyroved.models.trVAE(data_dim, latent_dim=2, coord=1)
    >>> # Initialize SVI trainer
    >>> trainer = SVItrainer(rvae)
    >>> kl_scale = torch.linspace(1, 4, 50)  # ramp-up KL scale factor from 1 to 4 during first 50 epochs
    >>> # Train
    >>> for e in range(100):
    >>>     sc = kl_scale[e] if e < len(kl_scale) else kl_scale[-1]
    >>>     trainer.step(train_loader, scale_factor=sc)
    >>>     trainer.print_statistics()
    """

    def __init__(self,
                 model: torch.nn.Module,
                 optimizer: Optional["optim.PyroOptim"] = None,
                 loss: Optional["infer.ELBO"] = None,
                 enumerate_parallel: bool = False,
                 seed: int = 1,
                 **kwargs: float
                 ) -> None:
        """
        Initializes the trainer's parameters
        """
        # Start from an empty global Pyro parameter store
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if optimizer is None:
            lr = kwargs.get("lr", 1e-3)
            optimizer = optim.Adam({"lr": lr})
        if loss is None:
            if enumerate_parallel:
                loss = infer.TraceEnum_ELBO(
                    max_plate_nesting=1, strict_enumeration_warning=False)
            else:
                loss = infer.Trace_ELBO()
        guide = model.guide
        if enumerate_parallel:
            # Applied to the guide regardless of whether the loss
            # was user-supplied or chosen above
            guide = infer.config_enumerate(guide, "parallel", expand=True)
        self.svi = infer.SVI(model.model, guide, optimizer, loss=loss)
        # One entry is appended per call to step()
        self.loss_history = {"training_loss": [], "test_loss": []}
        self.current_epoch = 0

    def _batch_loss(self, loss_fn, data, **kwargs: float) -> float:
        """
        Moves a single mini-batch to the trainer's device, passes it
        to loss_fn together with kwargs and returns loss_fn's result
        """
        if len(data) == 1:  # VAE mode
            return loss_fn(data[0].to(self.device), **kwargs)
        x, y = data  # VED or cVAE mode
        return loss_fn(x.to(self.device), y.to(self.device), **kwargs)

    def train(self,
              train_loader: torch.utils.data.DataLoader,
              **kwargs: float) -> float:
        """
        Trains a single epoch

        Args:
            train_loader: Dataloader yielding (x,) or (x, y) mini-batches
            **kwargs: Forwarded to every self.svi.step call

        Returns:
            Sum of the mini-batch losses divided by the number
            of samples in train_loader.dataset
        """
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch returned by the data loader
        for data in train_loader:
            # do ELBO gradient and accumulate loss
            epoch_loss += self._batch_loss(self.svi.step, data, **kwargs)

        return epoch_loss / len(train_loader.dataset)

    def evaluate(self,
                 test_loader: torch.utils.data.DataLoader,
                 **kwargs: float) -> float:
        """
        Evaluates current models state on a single epoch

        Args:
            test_loader: Dataloader yielding (x,) or (x, y) mini-batches
            **kwargs: Forwarded to every self.svi.evaluate_loss call

        Returns:
            Sum of the mini-batch losses divided by the number
            of samples in test_loader.dataset
        """
        # initialize loss accumulator
        test_loss = 0.
        # compute the loss over the entire test set.
        # evaluate_loss (not step) is used on purpose: step() is the
        # training entry point and also runs the optimizer, so calling it
        # here could change the parameters while evaluating on test data.
        # evaluate_loss computes the loss without gradients itself.
        for data in test_loader:
            test_loss += self._batch_loss(
                self.svi.evaluate_loss, data, **kwargs)

        return test_loss / len(test_loader.dataset)

    def step(self,
             train_loader: torch.utils.data.DataLoader,
             test_loader: Optional[torch.utils.data.DataLoader] = None,
             **kwargs: float) -> None:
        """
        Single training and (optionally) evaluation step

        Args:
            train_loader:
                Pytorch's dataloader object with training data
            test_loader:
                (Optional) Pytorch's dataloader object with test data
            **scale_factor:
                Scale factor for KL divergence. See e.g. https://arxiv.org/abs/1804.03599
                Default value is 1 (i.e. no scaling)
        """
        train_loss = self.train(train_loader, **kwargs)
        self.loss_history["training_loss"].append(train_loss)
        if test_loader is not None:
            test_loss = self.evaluate(test_loader, **kwargs)
            self.loss_history["test_loss"].append(test_loss)
        self.current_epoch += 1

    def print_statistics(self) -> None:
        """
        Prints training and test (if any) losses for current epoch
        """
        e = self.current_epoch
        train_loss = self.loss_history["training_loss"][-1]
        if self.loss_history["test_loss"]:
            template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
            print(template.format(
                e, train_loss, self.loss_history["test_loss"][-1]))
        else:
            template = 'Epoch: {} Training loss: {:.4f}'
            print(template.format(e, train_loss))