"""
trvae.py
=========
Variational autoencoder with rotational and/or translational invariances
Created by Maxim Ziatdinov (email: ziatdinovmax@gmail.com)
"""
from typing import Optional, Tuple, Union
import pyro
import pyro.distributions as dist
import torch
import torch.tensor as tt
from .base import baseVAE
from ..nets import fcDecoderNet, fcEncoderNet, sDecoderNet
from ..utils import (generate_grid, generate_latent_grid, get_sampler,
plot_img_grid, plot_spect_grid, set_deterministic_mode,
to_onehot, transform_coordinates)
class trVAE(baseVAE):
    """
    Variational autoencoder that enforces rotational and/or translational invariances

    Args:
        data_dim:
            Dimensionality of the input data; use (h x w) for images
            or (length,) for spectra.
        latent_dim:
            Number of latent dimensions.
        coord:
            For 2D systems, *coord=0* is vanilla VAE, *coord=1* enforces
            rotational invariance, *coord=2* enforces invariance to translations,
            and *coord=3* enforces both rotational and translational invariances.
            For 1D systems, *coord=0* is vanilla VAE and *coord>0* enforces
            translational invariance (any positive value is treated as 1).
        num_classes:
            Number of classes (if any) for class-conditioned (t)(r)VAE.
        hidden_dim_e:
            Number of hidden units per each layer in encoder (inference network).
            The default value is 128.
        hidden_dim_d:
            Number of hidden units per each layer in decoder (generator network).
            The default value is 128.
        num_layers_e:
            Number of layers in encoder (inference network).
            The default value is 2.
        num_layers_d:
            Number of layers in decoder (generator network).
            The default value is 2.
        activation:
            Non-linear activation for inner layers of encoder and decoder.
            The available activations are ReLU ('relu'), leaky ReLU ('lrelu'),
            hyperbolic tangent ('tanh'), and softplus ('softplus').
            The default activation is 'tanh'.
        sampler_d:
            Decoder sampler, as defined as p(x|z) = sampler(decoder(z)).
            The available samplers are 'bernoulli', 'continuous_bernoulli',
            and 'gaussian' (Default: 'bernoulli').
        sigmoid_d:
            Sigmoid activation for the decoder output (Default: True)
        seed:
            Seed used in torch.manual_seed(seed) and
            torch.cuda.manual_seed_all(seed)
        kwargs:
            Additional keyword arguments are *dx_prior* and *dy_prior* for setting
            a translational prior(s), and *decoder_sig* for setting sigma
            in the decoder's sampler when it is set to "gaussian".

    Raises:
        ValueError: If *coord* is not one of 0, 1, 2, 3 (after the 1D
            adjustment described above).

    Example:
        Initialize a VAE model with rotational invariance

        >>> data_dim = (28, 28)
        >>> rvae = trVAE(data_dim, latent_dim=2, coord=1)

        Initialize a class-conditioned VAE model with rotational invariance
        for dataset that has 10 classes

        >>> data_dim = (28, 28)
        >>> rvae = trVAE(data_dim, latent_dim=2, num_classes=10, coord=1)
    """

    def __init__(self,
                 data_dim: Tuple[int, ...],
                 latent_dim: int = 2,
                 coord: int = 3,
                 num_classes: int = 0,
                 hidden_dim_e: int = 128,
                 hidden_dim_d: int = 128,
                 num_layers_e: int = 2,
                 num_layers_d: int = 2,
                 activation: str = "tanh",
                 sampler_d: str = "bernoulli",
                 sigmoid_d: bool = True,
                 seed: int = 1,
                 **kwargs: float
                 ) -> None:
        """
        Initializes trVAE's modules and parameters
        """
        super().__init__()
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.ndim = len(data_dim)
        # 1D data has a single (translational) degree of freedom,
        # so any positive 'coord' collapses to 1
        if self.ndim == 1 and coord > 0:
            coord = 1
        # Validate *before* building any network, so that an invalid value
        # never reaches the encoder constructor
        if coord not in [0, 1, 2, 3]:
            raise ValueError("'coord' argument must be 0, 1, 2 or 3")

        # The encoder outputs 'coord' extra latent variables that are
        # interpreted as rotation angle and/or shift(s) (see split_latent)
        self.encoder_z = fcEncoderNet(
            data_dim, latent_dim+coord, 0, hidden_dim_e,
            num_layers_e, activation, softplus_out=True)
        # Decoder that also takes a coordinate grid is used in invariant modes
        dnet = sDecoderNet if coord in [1, 2, 3] else fcDecoderNet
        self.decoder = dnet(
            data_dim, latent_dim, num_classes, hidden_dim_d,
            num_layers_d, activation, sigmoid_out=sigmoid_d)
        self.sampler_d = get_sampler(sampler_d, **kwargs)

        self.z_dim = latent_dim + coord
        self.coord = coord
        self.num_classes = num_classes
        # NOTE(review): self.device is not set in this class;
        # presumably it is provided by baseVAE
        self.grid = generate_grid(data_dim).to(self.device)

        # Translational prior(s): dy_prior falls back to dx_prior if not given
        dx_pri = torch.tensor(kwargs.get("dx_prior", 0.1))
        dy_pri = kwargs.get("dy_prior", dx_pri.clone())
        t_prior = torch.tensor([dx_pri, dy_pri]) if self.ndim == 2 else dx_pri
        self.t_prior = t_prior.to(self.device)
        self.to(self.device)

    def model(self,
              x: torch.Tensor,
              y: Optional[torch.Tensor] = None,
              **kwargs: float) -> None:
        """
        Defines the model p(x|z)p(z)

        Args:
            x: Batch of data; the first dimension is the batch dimension.
            y: Optional batch of class vectors concatenated to the latent code.
            kwargs: 'scale_factor' for scaling the KLD term (Default: 1.)
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        # number of features per sample (product of all non-batch dimensions)
        reshape_ = x.shape[1:].numel()
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            with pyro.poutine.scale(scale=beta):
                z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            if self.coord > 0:  # rotationally- and/or translationally-invariant mode
                # Split latent variable into parts for rotation
                # and/or translation and image content
                phi, dx, z = self.split_latent(z)
                # dx is an all-zero placeholder when there is no translation part
                if torch.sum(dx.abs()) != 0:
                    dx = (dx * self.t_prior).unsqueeze(1)
                # transform coordinate grid
                grid = self.grid.expand(x.shape[0], *self.grid.shape)
                x_coord_prime = transform_coordinates(grid, phi, dx)
            # Add class label (if any)
            if y is not None:
                z = torch.cat([z, y], dim=-1)
            # decode the latent code z together with the transformed coordinates (if any)
            dec_args = (x_coord_prime, z) if self.coord else (z,)
            loc = self.decoder(*dec_args)
            # score against actual images ("binary cross-entropy loss")
            pyro.sample(
                "obs", self.sampler_d(loc.view(-1, reshape_)).to_event(1),
                obs=x.view(-1, reshape_))

    def guide(self,
              x: torch.Tensor,
              y: Optional[torch.Tensor] = None,
              **kwargs: float) -> None:
        """
        Defines the guide q(z|x)

        Args:
            x: Batch of data passed to the encoder.
            y: Not used by the guide; kept so that the signature
                matches the one of the model.
            kwargs: 'scale_factor' for scaling the KLD term (Default: 1.)
        """
        # register PyTorch module `encoder_z` with Pyro
        pyro.module("encoder_z", self.encoder_z)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder_z(x)
            # sample the latent code z
            with pyro.poutine.scale(scale=beta):
                pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    def split_latent(
        self, z: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
        """
        Split latent variable into parts for rotation
        and/or translation and image content

        Returns:
            Tuple (phi, dx, z) with the rotation part, the translation part
            and the remaining latent variables. In 1D, phi is None. In 2D,
            a part that is not encoded is returned as a scalar zero tensor.
        """
        # For 1D, there is only a translation
        if self.ndim == 1:
            return None, z[:, 0:1], z[:, 1:]
        # scalar zero placeholders for the parts that are not encoded
        phi, dx = torch.tensor(0), torch.tensor(0)
        # rotation + translation
        if self.coord == 3:
            phi = z[:, 0]  # encoded angle
            dx = z[:, 1:3]  # translation
            z = z[:, 3:]  # image content
        # translation only
        elif self.coord == 2:
            dx = z[:, :2]
            z = z[:, 2:]
        # rotation only
        elif self.coord == 1:
            phi = z[:, 0]
            z = z[:, 1:]
        return phi, dx, z

    def encode(self,
               x_new: torch.Tensor,
               **kwargs: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes data using a trained inference (encoder) network

        Args:
            x_new:
                Data to encode with a trained trVAE. The new data must have
                the same dimensions (images height and width or spectra length)
                as the one used for training.
            kwargs:
                Batch size as 'batch_size' (for encoding large volumes of data)

        Returns:
            Tuple (z_loc, z_scale) obtained by splitting the output of
            ``_encode`` into chunks of size ``z_dim`` along dimension 1.
        """
        # Forward kwargs (e.g. 'batch_size') the same way decode() does
        # NOTE(review): assumes _encode (defined outside this class)
        # accepts them, mirroring _decode
        z = self._encode(x_new, **kwargs)
        z_loc, z_scale = z.split(self.z_dim, 1)
        return z_loc, z_scale

    def decode(self,
               z: torch.Tensor,
               y: Optional[torch.Tensor] = None,
               **kwargs: int) -> torch.Tensor:
        """
        Decodes a batch of latent coordinates

        Args:
            z: Latent coordinates (without rotational and translational parts)
            y: Class (if any) as a batch of one-hot vectors
            kwargs: Batch size as 'batch_size'
        """
        z = z.to(self.device)
        if y is not None:
            z = torch.cat([z, y.to(self.device)], -1)
        loc = self._decode(z, **kwargs)
        return loc

    def manifold2d(self, d: int, plot: bool = True,
                   **kwargs: Union[str, int]) -> torch.Tensor:
        """
        Plots a learned latent manifold in the image space

        Args:
            d: Grid size
            plot: Plots the generated manifold (Default: True)
            kwargs: Keyword arguments include 'label' for class label (if any),
                    custom min/max values for grid boundaries passed as 'z_coord'
                    (e.g. z_coord = [-3, 3, -3, 3]) and plot parameters
                    ('padding', 'padding_value', 'cmap', 'origin', 'ylim')

        Returns:
            Decoder output for the grid of latent points, moved to CPU.
        """
        z, (grid_x, grid_y) = generate_latent_grid(d, **kwargs)
        z = z.to(self.device)
        if self.num_classes > 0:
            label = torch.as_tensor(kwargs.get("label", 0))
            # a label with fewer than 2 dimensions is treated as a class
            # index and converted to a one-hot row
            if label.ndim < 2:
                label = to_onehot(label.unsqueeze(0), self.num_classes)
            # keep the label on the same device as z before concatenation
            label = label.to(self.device)
            z = torch.cat([z, label.repeat(z.shape[0], 1)], dim=-1)
        dec_args = [z]
        if self.coord:
            # invariant modes: decoder takes the (untransformed) grid first
            dec_args.insert(0, self.grid.expand(z.shape[0], *self.grid.shape))
        with torch.no_grad():
            loc = self.decoder(*dec_args).cpu()
        if plot:
            if self.ndim == 2:
                plot_img_grid(
                    loc, d,
                    extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()],
                    **kwargs)
            elif self.ndim == 1:
                plot_spect_grid(loc, d, **kwargs)
        return loc