Trainers

SVI Trainer

class pyroved.trainers.SVItrainer(model, optimizer=None, loss=None, enumerate_parallel=False, seed=1, **kwargs)[source]

Bases: object

Stochastic variational inference (SVI) trainer for unsupervised and class-conditioned VED models consisting of one encoder and one decoder.

Parameters
  • model (Type[Module]) – Initialized model. Must be a subclass of torch.nn.Module and have self.model and self.guide methods

  • optimizer (Optional[Type[PyroOptim]]) – Pyro optimizer (Defaults to Adam with learning rate 1e-3)

  • loss (Optional[Type[ELBO]]) – ELBO objective (Defaults to pyro.infer.Trace_ELBO)

  • enumerate_parallel (bool) – Use exact (parallel) enumeration for discrete latent variables (Default: False)

  • seed (int) – Enforces reproducibility

  • kwargs (float) – learning rate passed as ‘lr’ (Default: 1e-3)

Example:

Train a model with the SVI trainer using default settings

>>> import pyroved
>>> # Initialize model
>>> data_dim = (28, 28)
>>> rvae = pyroved.models.trVAE(data_dim, latent_dim=2, coord=1)
>>> # Initialize SVI trainer
>>> trainer = pyroved.trainers.SVItrainer(rvae)
>>> # Train for 200 epochs:
>>> for _ in range(200):
...     trainer.step(train_loader)
...     trainer.print_statistics()

Train a model with the SVI trainer using a “time”-dependent KL scaling factor

>>> import torch
>>> import pyroved
>>> # Initialize model
>>> data_dim = (28, 28)
>>> rvae = pyroved.models.trVAE(data_dim, latent_dim=2, coord=1)
>>> # Initialize SVI trainer
>>> trainer = pyroved.trainers.SVItrainer(rvae)
>>> # Ramp up the KL scale factor from 1 to 4 over the first 50 epochs
>>> kl_scale = torch.linspace(1, 4, 50)
>>> # Train
>>> for e in range(100):
...     sc = kl_scale[e] if e < len(kl_scale) else kl_scale[-1]
...     trainer.step(train_loader, scale_factor=sc)
...     trainer.print_statistics()
train(train_loader, **kwargs)[source]

Trains the model for a single epoch

Return type

float
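
For example (a minimal sketch, assuming a trainer and train_loader set up as in the class-level examples above):

>>> loss = trainer.train(train_loader)  # run one training epoch
>>> print(loss)  # epoch training loss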

evaluate(test_loader, **kwargs)[source]

Evaluates the model’s current state on test data for a single epoch

Return type

float
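
For example (a sketch, assuming a test_loader holding held-out data):

>>> test_loss = trainer.evaluate(test_loader)  # loss on the test set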

step(train_loader, test_loader=None, **kwargs)[source]

Single training and (optionally) evaluation step

Parameters
  • train_loader (Type[DataLoader]) – PyTorch DataLoader with training data

  • test_loader (Optional[Type[DataLoader]]) – (Optional) PyTorch DataLoader with test data

  • **scale_factor – Scale factor for the KL divergence term. See e.g. https://arxiv.org/abs/1804.03599. Default value is 1 (i.e. no scaling)

Return type

None
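
For example (a sketch, assuming the loaders from the class-level examples; scale_factor=3 is an arbitrary illustrative value):

>>> trainer.step(train_loader, test_loader, scale_factor=3)
>>> trainer.print_statistics()  # report training and test losses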

print_statistics()[source]

Prints training and test (if any) losses for the current epoch

Return type

None

auxSVI Trainer

class pyroved.trainers.auxSVItrainer(model, optimizer=None, seed=1, **kwargs)[source]

Bases: object

Stochastic variational inference (SVI) trainer for variational models with auxiliary losses

Parameters
  • model (Type[Module]) – Initialized model. Must be a subclass of torch.nn.Module and have self.model and self.guide methods

  • optimizer (Optional[Type[PyroOptim]]) – Pyro optimizer (Defaults to Adam with learning rate 5e-4)

  • seed (int) – Enforces reproducibility

  • kwargs (float) – learning rate as ‘lr’ (Default: 5e-4)

Example:

>>> import pyroved
>>> # Initialize model for semi-supervised learning
>>> data_dim = (28, 28)
>>> ssvae = pyroved.models.sstrVAE(data_dim, latent_dim=2, num_classes=10, coord=1)
>>> # Initialize SVI trainer for models with auxiliary loss terms
>>> trainer = pyroved.trainers.auxSVItrainer(ssvae)
>>> # Train for 200 epochs:
>>> for _ in range(200):
...     trainer.step(loader_unsuperv, loader_superv, loader_valid)
...     trainer.print_statistics()
compute_loss(xs, ys=None, **kwargs)[source]

Computes basic and auxiliary losses

Return type

float
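
For example (a sketch, assuming loader_sup yields (xs, ys) batches of labeled data):

>>> xs, ys = next(iter(loader_sup))  # one labeled batch (assumed format)
>>> batch_loss = trainer.compute_loss(xs, ys)  # basic + auxiliary losses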

train(loader_unsup, loader_sup, **kwargs)[source]

Trains the model for a single epoch

Return type

float
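
For example (a sketch, assuming the unlabeled and labeled loaders from the class-level example):

>>> epoch_loss = trainer.train(loader_unsuperv, loader_superv)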

evaluate(loader_val)[source]

Evaluates the model’s current state on labeled test data

Return type

None
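
For example (a sketch, assuming loader_valid holds labeled validation data):

>>> trainer.evaluate(loader_valid)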

step(loader_unsup, loader_sup, loader_val=None, **kwargs)[source]

Single training (and, optionally, evaluation) step.

Parameters
  • loader_unsup (DataLoader) – PyTorch DataLoader with unlabeled training data

  • loader_sup (DataLoader) – PyTorch DataLoader with labeled training data

  • loader_val (Optional[DataLoader]) – PyTorch DataLoader with validation data

  • **scale_factor – Scale factor for the KL divergence term. See e.g. https://arxiv.org/abs/1804.03599. Default value is 1 (i.e. no scaling)

  • **aux_loss_multiplier – Hyperparameter that modulates the importance of the auxiliary loss term. See Eq. 9 in https://arxiv.org/abs/1406.5298. Default value is 20.

Return type

None
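
For example (a sketch; the scale_factor and aux_loss_multiplier values are arbitrary illustrative choices):

>>> trainer.step(loader_unsuperv, loader_superv, loader_valid,
...              scale_factor=1, aux_loss_multiplier=50)
>>> trainer.print_statistics()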

print_statistics()[source]

Prints training and test (if any) losses for the current epoch

Return type

None