"""
ved.py
=========
Variational encoder-decoder model (input and output are different)
Created by Maxim Ziatdinov (email: ziatdinovmax@gmail.com)
"""
from typing import Tuple, Union, List
import pyro
import pyro.distributions as dist
import torch
from .base import baseVAE
from ..nets import convEncoderNet, convDecoderNet
from ..utils import (generate_latent_grid, get_sampler,
init_dataloader, plot_img_grid, plot_spect_grid,
set_deterministic_mode)
class VED(baseVAE):
    """
    Variational encoder-decoder model where the inputs and outputs are not identical.
    This model can be used for realizing im2spec and spec2im type of models where
    1D spectra are predicted from image data and vice versa.

    Args:
        input_dim:
            Dimensionality of the input data; use (h x w) for images
            or (length,) for spectra.
        output_dim:
            Dimensionality of the output data; use (h x w) for images
            or (length,) for spectra. Doesn't have to match the input data.
        input_channels:
            Number of input channels (Default: 1)
        output_channels:
            Number of output channels (Default: 1)
        latent_dim:
            Number of latent dimensions (Default: 2).
        hidden_dim_e:
            Number of hidden units (convolutional filters) for each layer in
            the first block of the encoder NN. The number of units in the
            consecutive blocks is defined as hidden_dim_e * n,
            where n = 2, 3, ..., n_blocks (Default: 32).
        hidden_dim_d:
            Number of hidden units (convolutional filters) for each layer in
            the first block of the decoder NN. The number of units in the
            consecutive blocks is defined as hidden_dim_d // n,
            where n = 2, 3, ..., n_blocks (Default: 96).
        num_layers_e:
            List with numbers of layers per each block of the encoder NN.
            Defaults to [1, 2, 2] if none is specified.
        num_layers_d:
            List with numbers of layers per each block of the decoder NN.
            Defaults to [2, 2, 1] if none is specified.
        activation:
            Non-linear activation for inner layers of encoder and decoder.
            The available activations are ReLU ('relu'), leaky ReLU ('lrelu'),
            hyperbolic tangent ('tanh'), and softplus ('softplus').
            The default activation is 'lrelu'.
        batchnorm:
            Batch normalization attached to each convolutional layer
            after non-linear activation (except for layers with 1x1 filters)
            in the encoder and decoder NNs (Default: False)
        sampler_d:
            Decoder sampler, as defined as p(x|z) = sampler(decoder(z)).
            The available samplers are 'bernoulli', 'continuous_bernoulli',
            and 'gaussian' (Default: 'bernoulli').
        sigmoid_d:
            Sigmoid activation for the decoder output (Default: True)
        seed:
            Seed passed to set_deterministic_mode (used for
            torch.manual_seed(seed) and torch.cuda.manual_seed_all(seed))
        kwargs:
            Additional keyword argument is *decoder_sig* for setting sigma
            in the decoder's sampler when it is chosen to be a "gaussian".

    Example:

    Initialize a VED model for predicting 1D spectra from 2D images

    >>> input_dim = (32, 32) # image height and width
    >>> output_dim = (16,) # spectrum length
    >>> ved = VED(input_dim, output_dim, latent_dim=2)
    """

    def __init__(self,
                 input_dim: Tuple[int],
                 output_dim: Tuple[int],
                 input_channels: int = 1,
                 output_channels: int = 1,
                 latent_dim: int = 2,
                 hidden_dim_e: int = 32,
                 hidden_dim_d: int = 96,
                 num_layers_e: List[int] = None,
                 num_layers_d: List[int] = None,
                 activation: str = "lrelu",
                 batchnorm: bool = False,
                 sampler_d: str = "bernoulli",
                 sigmoid_d: bool = True,
                 seed: int = 1,
                 **kwargs: float
                 ) -> None:
        """
        Initializes VED's modules and parameters
        """
        super(VED, self).__init__()
        # Start from a clean Pyro parameter store so that parameters from
        # previously created models are not reused by this one.
        pyro.clear_param_store()
        set_deterministic_mode(seed)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Number of output data dimensions (2 -> images, 1 -> spectra);
        # used by manifold2d to choose the plotting routine.
        self.ndim = len(output_dim)
        self.encoder_z = convEncoderNet(
            input_dim, input_channels, latent_dim,
            num_layers_e, hidden_dim_e,
            batchnorm, activation)
        self.decoder = convDecoderNet(
            latent_dim, output_dim, output_channels,
            num_layers_d, hidden_dim_d,
            batchnorm, activation, sigmoid_d)
        self.sampler_d = get_sampler(sampler_d, **kwargs)
        self.z_dim = latent_dim
        self.to(self.device)

    def model(self,
              x: torch.Tensor = None,
              y: torch.Tensor = None,
              **kwargs: float) -> None:
        """
        Defines the model p(y|z)p(z)

        Args:
            x: Input batch; used here for the batch size and to create the
               prior parameters with x's dtype/device.
            y: Target batch that is scored against the decoded output.
            kwargs: 'scale_factor' (beta) that scales the latent site
               (KL term) in the ELBO (Default: 1).
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            with pyro.poutine.scale(scale=beta):
                z = pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc = self.decoder(z)
            # score against the actual targets
            pyro.sample(
                "obs", self.sampler_d(loc.flatten(1)).to_event(1),
                obs=y.flatten(1))

    def guide(self,
              x: torch.Tensor = None,
              y: torch.Tensor = None,
              **kwargs: float) -> None:
        """
        Defines the guide q(z|x)

        Args:
            x: Input batch fed to the encoder.
            y: Unused here; accepted so the guide matches the model signature.
            kwargs: 'scale_factor' (beta) that scales the latent site
               (KL term) in the ELBO (Default: 1).
        """
        # register PyTorch module `encoder_z` with Pyro
        pyro.module("encoder_z", self.encoder_z)
        # KLD scale factor (see e.g. https://openreview.net/pdf?id=Sy2fzU9gl)
        beta = kwargs.get("scale_factor", 1.)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder_z(x)
            # sample the latent code z
            with pyro.poutine.scale(scale=beta):
                pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))

    def encode(self,
               x_new: torch.Tensor,
               **kwargs: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes data using a trained inference (encoder) network

        Args:
            x_new:
                Data to encode with a trained VED. The new data must have
                the same dimensions (images height and width or spectra length)
                as the one used for training.
            kwargs:
                Batch size as 'batch_size' (for encoding large volumes of data)

        Returns:
            Tuple of latent locations (z_loc) and scales (z_scale)
        """
        self.eval()
        # Forward kwargs so the documented 'batch_size' option reaches the
        # batched encoding helper of the base class (mirrors decode()).
        z = self._encode(x_new, **kwargs)
        z_loc, z_scale = z.split(self.z_dim, 1)
        return z_loc, z_scale

    def decode(self,
               z: torch.Tensor,
               **kwargs: int) -> torch.Tensor:
        """
        Decodes a batch of latent coordinates

        Args:
            z: Latent coordinates
            kwargs: Forwarded to the base-class decoding helper
                (presumably 'batch_size' — confirm against baseVAE._decode)
        """
        self.eval()
        z = z.to(self.device)
        loc = self._decode(z, **kwargs)
        return loc

    def predict(self,
                x_new: torch.Tensor,
                **kwargs: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward prediction (encode -> sample -> decode)

        Args:
            x_new: Input data
            kwargs:
                'num_samples': number of latent samples drawn per input
                (Default: 30; use at least 2 for a finite standard deviation).
                All remaining keyword arguments (e.g. 'batch_size') are
                passed to init_dataloader.

        Returns:
            Mean and standard deviation (on CPU) of the decoded outputs,
            computed over the latent samples for each input.
        """
        num_samples = kwargs.pop("num_samples", 30)
        # Inference mode, consistent with encode/decode/manifold2d, so that
        # e.g. batch normalization uses its running statistics.
        self.eval()

        def forward_(x_i: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            with torch.no_grad():
                z_mu, z_sig = self.encoder_z(x_i)
                z_samples = dist.Normal(z_mu, z_sig).rsample(
                    sample_shape=(num_samples,))
                # stack decoded samples along a new leading "sample" axis
                y = torch.stack([self.decoder(z) for z in z_samples])
            return y.mean(0).cpu(), y.std(0).cpu()

        loader = init_dataloader(x_new, shuffle=False, **kwargs)
        prediction_mu, prediction_sd = [], []
        for (x_i,) in loader:
            y_mu, y_sd = forward_(x_i.to(self.device))
            prediction_mu.append(y_mu)
            prediction_sd.append(y_sd)
        return torch.cat(prediction_mu), torch.cat(prediction_sd)

    def manifold2d(self, d: int, plot: bool = True,
                   **kwargs: Union[str, int]) -> torch.Tensor:
        """
        Decodes a d x d grid of 2D latent coordinates and optionally plots it
        (as an image grid for 2D outputs or a spectra grid for 1D outputs)

        Args:
            d: Grid size
            plot: Plots the generated manifold (Default: True)
            kwargs: Keyword arguments include custom min/max values for grid
                boundaries passed as 'z_coord' (e.g. z_coord = [-3, 3, -3, 3])
                and plot parameters ('padding', 'padding_value', 'cmap', 'origin', 'ylim')

        Returns:
            Decoded outputs for the latent grid (on CPU)
        """
        self.eval()
        z, (grid_x, grid_y) = generate_latent_grid(d, **kwargs)
        z = z.to(self.device)
        with torch.no_grad():
            loc = self.decoder(z).cpu()
        if plot:
            if self.ndim == 2:
                plot_img_grid(
                    loc, d,
                    extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()],
                    **kwargs)
            elif self.ndim == 1:
                plot_spect_grid(loc, d, **kwargs)
        return loc