Trainers
SVI Trainer
class pyroved.trainers.SVItrainer(model, optimizer=None, loss=None, enumerate_parallel=False, seed=1, **kwargs)
Bases: object
Stochastic variational inference (SVI) trainer for unsupervised and class-conditioned VED models consisting of one encoder and one decoder.
- Parameters
  - model (Type[Module]) – Initialized model. Must be a subclass of torch.nn.Module and have self.model and self.guide methods.
  - optimizer (Optional[Type[PyroOptim]]) – Pyro optimizer (Defaults to Adam with learning rate 1e-3)
  - loss (Optional[Type[ELBO]]) – ELBO objective (Defaults to pyro.infer.Trace_ELBO)
  - enumerate_parallel (bool) – If True, discrete latent variables are enumerated exactly (in parallel)
  - seed (int) – Random seed, set for reproducibility
  - kwargs (float) – Learning rate passed as ‘lr’ (Default: 5e-4)
Example:
Train a model with the SVI trainer using default settings:
>>> import pyroved
>>> from pyroved.trainers import SVItrainer
>>> # Initialize model
>>> data_dim = (28, 28)
>>> rvae = pyroved.models.trVAE(data_dim, latent_dim=2, coord=1)
>>> # Initialize SVI trainer
>>> trainer = SVItrainer(rvae)
>>> # Train for 200 epochs:
>>> for _ in range(200):
...     trainer.step(train_loader)
...     trainer.print_statistics()
Train a model with the SVI trainer using a “time”-dependent KL scaling factor:
>>> import torch
>>> import pyroved
>>> from pyroved.trainers import SVItrainer
>>> # Initialize model
>>> data_dim = (28, 28)
>>> rvae = pyroved.models.trVAE(data_dim, latent_dim=2, coord=1)
>>> # Initialize SVI trainer
>>> trainer = SVItrainer(rvae)
>>> # Ramp up the KL scale factor from 1 to 4 during the first 50 epochs
>>> kl_scale = torch.linspace(1, 4, 50)
>>> # Train
>>> for e in range(100):
...     sc = kl_scale[e] if e < len(kl_scale) else kl_scale[-1]
...     trainer.step(train_loader, scale_factor=sc)
...     trainer.print_statistics()
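The second example ramps the KL scale factor linearly over the first 50 epochs and then holds it at its final value. A minimal sketch of the same schedule in plain Python; `kl_scale_at` is a hypothetical helper, not part of pyroved, and it computes the values in double precision rather than as a float32 tensor:

```python
def kl_scale_at(epoch: int, start: float = 1.0, end: float = 4.0, n_ramp: int = 50) -> float:
    """KL scale factor for `epoch` under a linear ramp that then stays flat.

    Epochs 0..n_ramp-1 follow the same evenly spaced values as
    torch.linspace(start, end, n_ramp); later epochs keep the last value.
    Hypothetical helper, not part of pyroved.
    """
    if n_ramp < 1:
        raise ValueError("n_ramp must be >= 1")
    if epoch < 0:
        raise ValueError("epoch must be >= 0")
    if n_ramp == 1:
        return float(start)  # torch.linspace(start, end, 1) holds only `start`
    i = min(epoch, n_ramp - 1)  # clamp: after the ramp, reuse the final value
    return start + (end - start) * i / (n_ramp - 1)


print([kl_scale_at(e) for e in (0, 49, 99)])  # [1.0, 4.0, 4.0]
```

Inside the training loop of the example, `trainer.step(train_loader, scale_factor=kl_scale_at(e))` would then replace the tensor indexing.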
evaluate(test_loader, **kwargs)
Evaluates the current model state on a single epoch.
- Return type: float
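`evaluate` reduces a whole epoch of test data to a single float. A minimal sketch of that bookkeeping, assuming the trainer wraps a Pyro `SVI` object (whose `evaluate_loss` computes a batch loss without a gradient step) and divides the summed loss by the number of test samples; both the normalization and the name `evaluate_epoch` are assumptions, not taken from pyroved. A stub stands in for `SVI` so the sketch runs without Pyro:

```python
from typing import Any, Iterable, Sequence


def evaluate_epoch(svi: Any, test_loader: Iterable[Sequence[Any]], **kwargs: Any) -> float:
    """Average loss per test sample after one pass over `test_loader`.

    `svi` is any object with an `evaluate_loss(*batch, **kwargs)` method, such as
    pyro.infer.SVI. Each batch is a tuple whose first element holds the inputs.
    """
    total_loss, n_samples = 0.0, 0
    for batch in test_loader:
        # evaluate_loss estimates the loss without taking a gradient step
        total_loss += float(svi.evaluate_loss(*batch, **kwargs))
        n_samples += len(batch[0])
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return total_loss / n_samples


class StubSVI:
    """Stand-in for pyro.infer.SVI so the sketch runs without Pyro."""

    def evaluate_loss(self, x, **kwargs):
        return sum(x)  # pretend the batch loss is the sum of its inputs


loader = [([1.0, 2.0],), ([3.0, 4.0],)]  # two batches of two samples each
print(evaluate_epoch(StubSVI(), loader))  # 10.0 / 4 = 2.5
```

With a real `pyro.infer.SVI` instance and a PyTorch DataLoader over a TensorDataset, each batch is a list such as `[x]`, so `batch[0]` and `len(batch[0])` give the inputs and the batch size.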
step(train_loader, test_loader=None, **kwargs)
Single training and (optionally) evaluation step.
- Parameters
  - train_loader (Type[DataLoader]) – PyTorch’s DataLoader with training data
  - test_loader (Optional[Type[DataLoader]]) – (Optional) PyTorch’s DataLoader with test data
  - **scale_factor – Scale factor for the KL divergence term (see e.g. https://arxiv.org/abs/1804.03599). Default value is 1 (i.e. no scaling).
- Return type: None
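Read as in the β-VAE paper cited for `scale_factor`, the factor β weights the KL divergence in the negative evidence lower bound that SVI minimizes: β = 1 gives the standard ELBO, and larger values penalize deviation of the approximate posterior from the prior more strongly. The notation (encoder $q_\phi(z\mid x)$, decoder $p_\theta(x\mid z)$, prior $p(z)$) is introduced here for illustration and is not taken from pyroved:

```latex
\mathcal{L}_\beta(\theta, \phi; x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] + \beta \, \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
```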
auxSVI Trainer
class pyroved.trainers.auxSVItrainer(model, optimizer=None, seed=1, **kwargs)
Bases: object
Stochastic variational inference (SVI) trainer for variational models with auxiliary losses.
- Parameters
  - model (Type[Module]) – Initialized model. Must be a subclass of torch.nn.Module and have self.model and self.guide methods.
  - optimizer (Optional[Type[PyroOptim]]) – Pyro optimizer (Defaults to Adam with learning rate 5e-4)
  - seed (int) – Random seed, set for reproducibility
  - kwargs (float) – Learning rate passed as ‘lr’ (Default: 5e-4)
Example:
>>> import pyroved
>>> from pyroved.trainers import auxSVItrainer
>>> # Initialize model for semi-supervised learning
>>> data_dim = (28, 28)
>>> ssvae = pyroved.models.sstrVAE(data_dim, latent_dim=2, num_classes=10, coord=1)
>>> # Initialize SVI trainer for models with auxiliary loss terms
>>> trainer = auxSVItrainer(ssvae)
>>> # Train for 200 epochs:
>>> for _ in range(200):
...     trainer.step(loader_unsuperv, loader_superv, loader_valid)
...     trainer.print_statistics()
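For a semi-supervised model like the one in this example, the auxiliary loss is the classification term of Kingma et al. (https://arxiv.org/abs/1406.5298, Eq. 9). The objective $\mathcal{J}$, which sums the negative variational bounds of the labeled and unlabeled data, is augmented by the negative log-likelihood of the classifier $q_\phi(y\mid x)$ on labeled data, weighted by $\alpha$ (the `aux_loss_multiplier` keyword of `step`). Here $\tilde{p}_l(x, y)$ is the empirical distribution of the labeled data, following the paper's notation:

```latex
\mathcal{J}^{\alpha} = \mathcal{J} + \alpha \cdot \mathbb{E}_{\tilde{p}_l(x, y)}\left[-\log q_\phi(y \mid x)\right]
```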
step(loader_unsup, loader_sup, loader_val=None, **kwargs)
Single training step (followed by evaluation if loader_val is given).
- Parameters
  - loader_unsup (DataLoader) – PyTorch’s DataLoader with unlabeled training data
  - loader_sup (DataLoader) – PyTorch’s DataLoader with labeled training data
  - loader_val (Optional[DataLoader]) – PyTorch’s DataLoader with validation data
  - **scale_factor – Scale factor for the KL divergence term (see e.g. https://arxiv.org/abs/1804.03599). Default value is 1 (i.e. no scaling).
  - **aux_loss_multiplier – Hyperparameter that modulates the importance of the auxiliary loss term (see Eq. 9 in https://arxiv.org/abs/1406.5298). Default value is 20.
- Return type: None
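`step` consumes an unlabeled and a labeled loader within the same epoch. One common way to combine them, used in Pyro's semi-supervised VAE tutorial, is to spread the labeled batches evenly among the unlabeled ones; whether pyroved uses exactly this schedule is an assumption. The hypothetical helper below only computes the visiting order, so it runs without Pyro or PyTorch:

```python
from typing import List


def interleave_batches(num_unsup: int, num_sup: int) -> List[str]:
    """Order in which labeled ("sup") and unlabeled ("unsup") batches are visited.

    Labeled batches are spread evenly over the epoch, one every `period` steps,
    in the spirit of Pyro's semi-supervised VAE tutorial. Hypothetical helper,
    not part of pyroved.
    """
    if num_sup < 1 or num_unsup < 0:
        raise ValueError("need num_sup >= 1 and num_unsup >= 0")
    total = num_sup + num_unsup
    period = total // num_sup  # >= 1 because total >= num_sup
    order: List[str] = []
    sup_used = 0
    for i in range(total):
        # Take a labeled batch at the start of each period while any remain
        if i % period == 0 and sup_used < num_sup:
            order.append("sup")
            sup_used += 1
        else:
            order.append("unsup")
    return order


# 2 labeled and 4 unlabeled batches: one labeled batch every 3 steps
print(interleave_batches(num_unsup=4, num_sup=2))
# ['sup', 'unsup', 'unsup', 'sup', 'unsup', 'unsup']
```

Because the helper emits exactly `num_sup` "sup" tags and `num_unsup` "unsup" tags, a training loop can draw the next batch from an iterator over `loader_sup` or `loader_unsup` according to each tag without exhausting either loader early.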