"""
sstrvae.py
=========
Semi-supervised variational autoencoder for data
with positional (rotation+translation) disorder
Created by Maxim Ziatdinov (email: ziatdinovmax@gmail.com)
"""
from typing import Optional, Tuple, Union, Type
import pyro
import pyro.distributions as dist
import torch
import torch.tensor as tt
from .base import baseVAE
from ..nets import fcDecoderNet, fcEncoderNet, sDecoderNet, fcClassifierNet
from ..utils import (generate_grid, get_sampler, plot_img_grid,
plot_spect_grid, set_deterministic_mode, to_onehot,
transform_coordinates, init_dataloader, generate_latent_grid)
[docs]class sstrVAE(baseVAE):
"""
Semi-supervised variational autoencoder with rotational and/or translational invariance
Args:
data_dim:
Dimensionality of the input data; use (h x w) for images
or (length,) for spectra.
latent_dim:
Number of latent dimensions.
num_classes:
Number of classes in the classification scheme
coord:
For 2D systems, *coord=0* is vanilla VAE, *coord=1* enforces
rotational invariance, *coord=2* enforces invariance to translations,
and *coord=3* enforces both rotational and translational invariances.
For 1D systems, *coord=0* is vanilla VAE and *coord>0* enforces
transaltional invariance.
hidden_dim_e:
Number of hidden units per each layer in encoder (inference network).
hidden_dim_d:
Number of hidden units per each layer in decoder (generator network).
hidden_dim_cls:
Number of hidden units ("neurons") in each layer of classifier
num_layers_e:
Number of layers in encoder (inference network).
num_layers_d:
Number of layers in decoder (generator network).
num_layers_cls:
Number of layers in classifier
sampler_d:
Decoder sampler, as defined as p(x|z) = sampler(decoder(z)).
The available samplers are 'bernoulli', 'continuous_bernoulli',
and 'gaussian' (Default: 'bernoulli').
sigmoid_d:
Sigmoid activation for the decoder output (Default: True)
seed:
Seed used in torch.manual_seed(seed) and
torch.cuda.manual_seed_all(seed)
kwargs:
Additional keyword arguments are *dx_prior* and *dy_prior* for setting
a translational prior(s), and *decoder_sig* for setting sigma
in the decoder's sampler when it is set to "gaussian".
Example:
Initialize a VAE model with rotational invariance for
semisupervised learning of the dataset that has 10 classes
>>> data_dim = (28, 28)
>>> ssvae = sstrVAE(data_dim, latent_dim=2, num_classes=10, coord=1)
"""
def __init__(self,
data_dim: Tuple[int],
latent_dim: int,
num_classes: int,
coord: int = 3,
hidden_dim_e: int = 128,
hidden_dim_d: int = 128,
hidden_dim_cls: int = 128,
num_layers_e: int = 2,
num_layers_d: int = 2,
num_layers_cls: int = 2,
sampler_d: str = "bernoulli",
sigmoid_d: bool = True,
seed: int = 1,
**kwargs: float
) -> None:
"""
Initializes sstrVAE parameters
"""
super(sstrVAE, self).__init__()
pyro.clear_param_store()
set_deterministic_mode(seed)
if coord not in [0, 1, 2, 3]:
raise ValueError("'coord' argument must be 0, 1, 2 or 3")
self.ndim = len(data_dim)
if self.ndim == 1 and coord > 0:
coord = 1
self.data_dim = data_dim
self.encoder_z = fcEncoderNet(
data_dim, latent_dim+coord, num_classes,
hidden_dim_e, num_layers_e, flat=False)
self.encoder_y = fcClassifierNet(
data_dim, num_classes, hidden_dim_cls, num_layers_cls)
dnet = sDecoderNet if coord in [1, 2, 3] else fcDecoderNet
self.decoder = dnet(
data_dim, latent_dim, num_classes, hidden_dim_d,
num_layers_d, sigmoid_out=sigmoid_d, unflat=False)
self.sampler_d = get_sampler(sampler_d, **kwargs)
self.z_dim = latent_dim + coord
self.num_classes = num_classes
self.coord = coord
self.grid = generate_grid(data_dim).to(self.device)
dx_pri = tt(kwargs.get("dx_prior", 0.1))
dy_pri = kwargs.get("dy_prior", dx_pri.clone())
t_prior = tt([dx_pri, dy_pri]) if self.ndim == 2 else dx_pri
self.t_prior = t_prior.to(self.device)
self.to(self.device)
[docs] def model(self,
xs: torch.Tensor,
ys: Optional[torch.Tensor] = None,
**kwargs: float) -> None:
"""
Model of the generative process p(x|z,y)p(y)p(z)
"""
pyro.module("ss_vae", self)
batch_dim = xs.size(0)
specs = dict(dtype=xs.dtype, device=xs.device)
beta = kwargs.get("scale_factor", 1.)
# pyro.plate enforces independence between variables in batches xs, ys
with pyro.plate("data"):
# sample the latent vector from the constant prior distribution
prior_loc = torch.zeros(batch_dim, self.z_dim, **specs)
prior_scale = torch.ones(batch_dim, self.z_dim, **specs)
with pyro.poutine.scale(scale=beta):
zs = pyro.sample(
"z", dist.Normal(prior_loc, prior_scale).to_event(1))
# split latent variable into parts for rotation and/or translation
# and image content
if self.coord > 0:
phi, dx, zs = self.split_latent(zs)
if torch.sum(dx.abs()) != 0:
dx = (dx * self.t_prior).unsqueeze(1)
# transform coordinate grid
if self.ndim > 1:
expdim = dx.shape[0] if self.coord > 1 else phi.shape[0]
else:
expdim = dx.shape[0]
grid = self.grid.expand(expdim, *self.grid.shape)
x_coord_prime = transform_coordinates(grid, phi, dx)
# sample label from the constant prior or observe the value
alpha_prior = (torch.ones(batch_dim, self.num_classes, **specs) /
self.num_classes)
ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)
# Score against the parametrized distribution
# p(x|y,z) = bernoulli(decoder(y,z))
d_args = (x_coord_prime, [zs, ys]) if self.coord else ([zs, ys],)
loc = self.decoder(*d_args)
loc = loc.view(*ys.shape[:-1], -1)
pyro.sample("x", self.sampler_d(loc).to_event(1), obs=xs)
[docs] def guide(self, xs: torch.Tensor,
ys: Optional[torch.Tensor] = None,
**kwargs: float) -> None:
"""
Guide q(z|y,x)q(y|x)
"""
beta = kwargs.get("scale_factor", 1.)
with pyro.plate("data"):
# sample and score the digit with the variational distribution
# q(y|x) = categorical(alpha(x))
if ys is None:
alpha = self.encoder_y(xs)
ys = pyro.sample("y", dist.OneHotCategorical(alpha))
# sample (and score) the latent vector with the variational
# distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
loc, scale = self.encoder_z([xs, ys])
with pyro.poutine.scale(scale=beta):
pyro.sample("z", dist.Normal(loc, scale).to_event(1))
[docs] def split_latent(self, zs: torch.Tensor) -> Tuple[torch.Tensor]:
"""
Split latent variable into parts with rotation and/or translation
and image content
"""
zdims = list(zs.shape)
zdims[-1] = zdims[-1] - self.coord
zs = zs.view(-1, zs.size(-1))
# For 1D, there is only translation
if self.ndim == 1:
dx = zs[:, 0:1]
zs = zs[:, 1:]
return None, dx, zs.view(*zdims)
phi, dx = tt(0), tt(0)
# rotation + translation
if self.coord == 3:
phi = zs[:, 0] # encoded angle
dx = zs[:, 1:3] # translation
zs = zs[:, 3:] # image content
# translation only
elif self.coord == 2:
dx = zs[:, :2]
zs = zs[:, 2:]
# rotation only
elif self.coord == 1:
phi = zs[:, 0]
zs = zs[:, 1:]
zs = zs.view(*zdims)
return phi, dx, zs
[docs] def model_classify(self, xs: torch.Tensor,
ys: Optional[torch.Tensor] = None,
**kwargs: float) -> None:
"""
Models an auxiliary (supervised) loss
"""
pyro.module("ss_vae", self)
with pyro.plate("data"):
# the extra term to yield an auxiliary loss
aux_loss_multiplier = kwargs.get("aux_loss_multiplier", 20)
if ys is not None:
alpha = self.encoder_y.forward(xs)
with pyro.poutine.scale(scale=aux_loss_multiplier):
pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)
[docs] def guide_classify(self, xs, ys=None, **kwargs):
"""
Dummy guide function to accompany model_classify
"""
pass
[docs] def set_classifier(self, cls_net: Type[torch.nn.Module]) -> None:
"""
Sets a user-defined classification network
"""
self.encoder_y = cls_net
[docs] def classifier(self,
x_new: torch.Tensor,
**kwargs: int) -> torch.Tensor:
"""
Classifies data
Args:
x_new:
Data to classify with a trained ss-trVAE. The new data must have
the same dimensions (images height x width or spectra length)
as the one used for training.
kwargs:
Batch size as 'batch_size' (for encoding large volumes of data)
"""
def classify(x_i) -> torch.Tensor:
with torch.no_grad():
alpha = self.encoder_y(x_i)
_, predicted = torch.max(alpha.data, 1)
return predicted.cpu()
x_new = init_dataloader(x_new, shuffle=False, **kwargs)
y_predicted = []
for (x_i,) in x_new:
y_predicted.append(classify(x_i.to(self.device)))
return torch.cat(y_predicted)
[docs] def encode(self,
x_new: torch.Tensor,
y: Optional[torch.Tensor] = None,
**kwargs: int) -> torch.Tensor:
"""
Encodes data using a trained inference (encoder) network
Args:
x_new:
Data to encode with a trained trVAE. The new data must have
the same dimensions (images height and width or spectra length)
as the one used for training.
y:
Classes as one-hot vectors for each sample in x_new. If not provided,
the ss-trVAE's classifier will be used to predict the classes.
kwargs:
Batch size as 'batch_size' (for encoding large volumes of data)
"""
if y is None:
y = self.classifier(x_new, **kwargs)
if y.ndim < 2:
y = to_onehot(y, self.num_classes)
z = self._encode(x_new, y, **kwargs)
z_loc, z_scale = z.split(self.z_dim, 1)
_, y_pred = torch.max(y, 1)
return z_loc, z_scale, y_pred
[docs] def decode(self, z: torch.Tensor, y: torch.Tensor, **kwargs: int) -> torch.Tensor:
"""
Decodes a batch of latent coordinates
Args:
z: Latent coordinates (without rotational and translational parts)
y: Classes as one-hot vectors for each sample in z
kwargs: Batch size as 'batch_size'
"""
z = torch.cat([z.to(self.device), y.to(self.device)], -1)
loc = self._decode(z, **kwargs)
return loc.view(-1, *self.data_dim)
[docs] def manifold2d(self, d: int, plot: bool = True,
**kwargs: Union[str, int]) -> torch.Tensor:
"""
Returns a learned latent manifold in the image space
Args:
d: Grid size
plot: Plots the generated manifold (Default: True)
kwargs: Keyword arguments include 'label' for class label (if any),
custom min/max values for grid boundaries passed as 'z_coord'
(e.g. z_coord = [-3, 3, -3, 3]) and plot parameters
('padding', 'padding_value', 'cmap', 'origin', 'ylim')
"""
cls = tt(kwargs.get("label", 0))
if cls.ndim < 2:
cls = to_onehot(cls.unsqueeze(0), self.num_classes)
z, (grid_x, grid_y) = generate_latent_grid(d, **kwargs)
z = z.to(self.device)
z = torch.cat([z, cls.repeat(z.shape[0], 1)], dim=-1)
z = [z]
if self.coord:
grid = [self.grid.expand(z[0].shape[0], *self.grid.shape)]
z = grid + z
with torch.no_grad():
loc = self.decoder(*z).cpu()
loc = loc.view(-1, *self.data_dim)
if plot:
if self.ndim == 2:
plot_img_grid(
loc, d,
extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()],
**kwargs)
elif self.ndim == 1:
plot_spect_grid(loc, d, **kwargs)
return loc