"""
jtrvae.py
=========
Variational autoencoder for learning (jointly) discrete and
continuous latent representations on data with arbitrary rotations
and/or translations
Created by Maxim Ziatdinov (email: ziatdinovmax@gmail.com)
"""
from typing import Tuple, Union
import pyro
import pyro.distributions as dist
import torch
import torch.tensor as tt
from .base import baseVAE
from ..nets import fcDecoderNet, jfcEncoderNet, sDecoderNet
from ..utils import (generate_grid, get_sampler, plot_img_grid,
plot_spect_grid, set_deterministic_mode,
transform_coordinates, to_onehot, generate_latent_grid)
class jtrVAE(baseVAE):
    """
    Variational autoencoder for learning (jointly) discrete and
    continuous latent representations on data with arbitrary rotations
    and/or translations

    Args:
        data_dim:
            Dimensionality of the input data; use (h x w) for images
            or (length,) for spectra.
        latent_dim:
            Number of continuous latent dimensions.
        discrete_dim:
            Number of discrete latent dimensions
        coord:
            For 2D systems, *coord=0* is vanilla VAE, *coord=1* enforces
            rotational invariance, *coord=2* enforces invariance to translations,
            and *coord=3* enforces both rotational and translational invariances.
            For 1D systems, *coord=0* is vanilla VAE and *coord>0* enforces
            translational invariance.
        hidden_dim_e:
            Number of hidden units per each layer in encoder (inference network).
        hidden_dim_d:
            Number of hidden units per each layer in decoder (generator network).
        num_layers_e:
            Number of layers in encoder (inference network).
        num_layers_d:
            Number of layers in decoder (generator network).
        activation:
            Non-linear activation for inner layers of encoder and decoder.
            The available activations are ReLU ('relu'), leaky ReLU ('lrelu'),
            hyperbolic tangent ('tanh'), and softplus ('softplus')
            The default activation is 'tanh'.
        sampler_d:
            Decoder sampler, as defined as p(x|z) = sampler(decoder(z)).
            The available samplers are 'bernoulli', 'continuous_bernoulli',
            and 'gaussian' (Default: 'bernoulli').
        sigmoid_d:
            Sigmoid activation for the decoder output (Default: True)
        seed:
            Seed used in torch.manual_seed(seed) and
            torch.cuda.manual_seed_all(seed)
        kwargs:
            Additional keyword arguments are *dx_prior* and *dy_prior* for setting
            a translational prior(s), and *decoder_sig* for setting sigma
            in the decoder's sampler when it is set to "gaussian".

    Raises:
        ValueError: If *coord* is not one of 0, 1, 2, 3 (for 1D data any
            positive *coord* is first reduced to 1).

    Example:
        Initialize a joint VAE model with rotational invariance for 10 discrete classes

        >>> data_dim = (28, 28)
        >>> jvae = jtrVAE(data_dim, latent_dim=2, discrete_dim=10, coord=1)
    """

    def __init__(self,
                 data_dim: Tuple[int, ...],
                 latent_dim: int,
                 discrete_dim: int,
                 coord: int = 0,
                 hidden_dim_e: int = 128,
                 hidden_dim_d: int = 128,
                 num_layers_e: int = 2,
                 num_layers_d: int = 2,
                 activation: str = "tanh",
                 sampler_d: str = "bernoulli",
                 sigmoid_d: bool = True,
                 seed: int = 1,
                 **kwargs: float
                 ) -> None:
        """
        Initializes jtrVAE's modules and parameters
        """
        super(jtrVAE, self).__init__()
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.ndim = len(data_dim)
        self.data_dim = data_dim
        # 1D data: any positive `coord` is reduced to a single extra latent
        if self.ndim == 1 and coord > 0:
            coord = 1
        # Reject bad values before any network is sized from `coord`
        if coord not in [0, 1, 2, 3]:
            raise ValueError("'coord' argument must be 0, 1, 2 or 3")
        # Encoder is built before the decoder so that, for a given seed,
        # module construction order (and hence initialization) is unchanged
        self.encoder_z = jfcEncoderNet(
            data_dim, latent_dim+coord, discrete_dim, hidden_dim_e,
            num_layers_e, activation, softplus_out=True)
        # Coordinate-aware decoder whenever an invariance is requested
        dnet = sDecoderNet if coord in [1, 2, 3] else fcDecoderNet
        self.decoder = dnet(
            data_dim, latent_dim, discrete_dim, hidden_dim_d,
            num_layers_d, activation, sigmoid_out=sigmoid_d, unflat=False)
        self.sampler_d = get_sampler(sampler_d, **kwargs)
        # The continuous latent vector carries `coord` extra entries that
        # split_latent() later peels off as rotation/translation
        self.z_dim = latent_dim + coord
        self.coord = coord
        self.discrete_dim = discrete_dim
        self.grid = generate_grid(data_dim).to(self.device)
        # Translation prior scale(s); dy defaults to dx when not given
        dx_prior = float(kwargs.get("dx_prior", 0.1))
        dy_prior = float(kwargs.get("dy_prior", dx_prior))
        t_prior = torch.tensor(
            [dx_prior, dy_prior] if self.ndim == 2 else dx_prior)
        self.t_prior = t_prior.to(self.device)
        self.to(self.device)

    def model(self,
              x: torch.Tensor,
              **kwargs: float) -> None:
        """
        Defines the model p(x|z,c)p(z)p(c)

        Args:
            x:
                Batch of data. The first dimension is used as the batch
                dimension; all remaining dimensions are flattened when
                scoring the reconstruction.
            kwargs:
                *scale_factor* sets the scale applied to the latent sample
                sites (Default: 1.)
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        bdim = x.shape[0]
        # number of features per sample after flattening
        n_features = x.shape[1:].numel()
        with pyro.plate("data"):
            # parameters of the constant prior over the continuous latent vector
            z_loc = x.new_zeros(torch.Size((bdim, self.z_dim)))
            z_scale = x.new_ones(torch.Size((bdim, self.z_dim)))
            # uniform class probabilities for the discrete latent prior
            alpha = x.new_ones(torch.Size((bdim, self.discrete_dim))) / self.discrete_dim
            # sample from prior (value will be sampled by guide when computing ELBO)
            with pyro.poutine.scale(scale=beta):
                z = pyro.sample("latent_cont", dist.Normal(z_loc, z_scale).to_event(1))
                z_disc = pyro.sample("latent_disc", dist.OneHotCategorical(alpha))
            # split latent variable into parts for rotation and/or translation
            # and image content
            if self.coord > 0:
                # the continuous code is tiled `discrete_dim` times --
                # presumably to pair each sample with every enumerated class;
                # confirm against the ELBO used for training
                phi, dx, z = self.split_latent(z.repeat(self.discrete_dim, 1))
                # split_latent returns a zero placeholder when there is no
                # translation part; only real shifts are scaled by the prior
                if torch.sum(dx.abs()) != 0:
                    dx = (dx * self.t_prior).unsqueeze(1)
                # transform coordinate grid
                grid = self.grid.expand(bdim*self.discrete_dim, *self.grid.shape)
                x_coord_prime = transform_coordinates(grid, phi, dx)
            # Continuous and discrete latent variables for the decoder
            z = [z, z_disc.reshape(-1, self.discrete_dim) if self.coord > 0 else z_disc]
            # decode the latent code z together with the transformed coordinates (if any)
            dec_args = (x_coord_prime, z) if self.coord else (z,)
            loc = self.decoder(*dec_args)
            # score against actual images/spectra
            loc = loc.view(*z_disc.shape[:-1], n_features)
            pyro.sample(
                "obs", self.sampler_d(loc).to_event(1),
                obs=x.view(-1, n_features))

    def guide(self,
              x: torch.Tensor,
              **kwargs: float) -> None:
        """
        Defines the guide q(z,c|x)

        Args:
            x:
                Batch of data passed to the encoder.
            kwargs:
                *scale_factor* sets the scale applied to the latent sample
                sites (Default: 1.)
        """
        # register PyTorch module `encoder_z` with Pyro
        pyro.module("encoder_z", self.encoder_z)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        with pyro.plate("data"):
            # use the encoder to get the parameters used to define q(z,c|x)
            z_loc, z_scale, alpha = self.encoder_z(x)
            # sample the latent code z
            with pyro.poutine.scale(scale=beta):
                pyro.sample("latent_cont", dist.Normal(z_loc, z_scale).to_event(1))
                pyro.sample("latent_disc", dist.OneHotCategorical(alpha))

    def split_latent(
        self, zs: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], torch.Tensor, torch.Tensor]:
        """
        Split latent variable into parts with rotation and/or translation
        and image content

        Args:
            zs: 2D tensor whose leading columns hold the rotation and/or
                translation entries, followed by the content entries.

        Returns:
            Tuple ``(phi, dx, zs)``. For 1D data ``phi`` is ``None`` and
            ``dx`` is the first column. For 2D data, a part that is not
            encoded for the current *coord* is returned as a scalar zero
            tensor.
        """
        if self.ndim == 1:
            # a single translation entry, kept 2D via the 0:1 slice
            return None, zs[:, 0:1], zs[:, 1:]
        # scalar-zero placeholders for parts that are not encoded
        phi, dx = torch.tensor(0), torch.tensor(0)
        if self.coord == 3:
            # rotation + translation: [angle | dx, dy | content]
            phi, dx, zs = zs[:, 0], zs[:, 1:3], zs[:, 3:]
        elif self.coord == 2:
            # translation only: [dx, dy | content]
            dx, zs = zs[:, :2], zs[:, 2:]
        elif self.coord == 1:
            # rotation only: [angle | content]
            phi, zs = zs[:, 0], zs[:, 1:]
        return phi, dx, zs

    def encode(self,
               x_new: torch.Tensor,
               **kwargs: int
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes data using a trained inference (encoder) network

        Args:
            x_new:
                Data to encode with a trained jtrVAE. The new data must have
                the same dimensions (images height and width or spectra length)
                as the one used for training.
            kwargs:
                Batch size as 'batch_size' (for encoding large volumes of data)

        Returns:
            Tuple ``(z_loc, z_scale, pred_labels)``: the first ``z_dim``
            columns of the encoder output, the next ``z_dim`` columns, and
            the index of the largest entry among the remaining columns
            for each sample.
        """
        # Forward kwargs (e.g. batch_size) as documented above; decode()
        # forwards to _decode in the same way.
        # NOTE(review): assumes baseVAE._encode accepts these keyword
        # arguments -- confirm against base.py
        z = self._encode(x_new, **kwargs)
        z_loc = z[:, :self.z_dim]
        z_scale = z[:, self.z_dim:2*self.z_dim]
        alphas = z[:, 2*self.z_dim:]
        # predicted class = position of the largest class entry
        _, pred_labels = torch.max(alphas, 1)
        return z_loc, z_scale, pred_labels

    def decode(self, z: torch.Tensor, y: torch.Tensor, **kwargs: int) -> torch.Tensor:
        """
        Decodes a batch of latent coordinates

        Args:
            z: Latent coordinates (without rotational and translational parts)
            y: Classes as one-hot vectors for each sample in z
            kwargs: Forwarded unchanged to the base-class ``_decode``

        Returns:
            Decoded data reshaped to ``(-1, *data_dim)``
        """
        # the decoder input is the continuous code followed by the one-hot class
        z = torch.cat([z.to(self.device), y.to(self.device)], -1)
        loc = self._decode(z, **kwargs)
        return loc.view(-1, *self.data_dim)

    def manifold2d(self, d: int, disc_idx: int = 0, plot: bool = True,
                   **kwargs: Union[str, int]) -> torch.Tensor:
        """
        Plots a learned latent manifold in the image space

        Args:
            d: Grid size
            disc_idx: Discrete dimension for which we plot continuous latent manifolds
            plot: Plots the generated manifold (Default: True)
            kwargs: Keyword arguments include custom min/max values for grid
                    boundaries passed as 'z_coord' (e.g. z_coord = [-3, 3, -3, 3])
                    and plot parameters ('padding', 'padding_value', 'cmap', 'origin', 'ylim')

        Returns:
            Decoded grid reshaped to ``(-1, *data_dim)`` (on CPU)
        """
        z_disc = to_onehot(torch.tensor(disc_idx).unsqueeze(0), self.discrete_dim)
        z, (grid_x, grid_y) = generate_latent_grid(d, **kwargs)
        z = z.to(self.device)
        # keep both parts on the same device before concatenating
        z_disc = z_disc.to(self.device)
        z = torch.cat([z, z_disc.repeat(z.shape[0], 1)], dim=-1)
        z = [z]
        if self.coord:
            # coordinate-aware decoder takes the (untransformed) grid first
            grid = [self.grid.expand(z[0].shape[0], *self.grid.shape)]
            z = grid + z
        with torch.no_grad():
            loc = self.decoder(*z).cpu()
        loc = loc.view(-1, *self.data_dim)
        if plot:
            if self.ndim == 2:
                plot_img_grid(
                    loc, d,
                    extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()],
                    **kwargs)
            elif self.ndim == 1:
                plot_spect_grid(loc, d, **kwargs)
        return loc