# Source code for pyroved.nets.fc

"""
fc.py

Module for creating fully-connected encoder and decoder modules

Created by Maxim Ziatdinov (ziatdinovmax@gmail.com)
"""

from typing import List, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.tensor as tt
from pyro.distributions.util import broadcast_shape

from ..utils import get_activation


class Concat(nn.Module):
    """
    Joins a collection of tensors along their last dimension
    """
    def __init__(self, allow_broadcast: bool = True):
        """
        Stores the broadcasting flag

        Args:
            allow_broadcast:
                When True, the leading (all but last) dimensions of the
                inputs are broadcast against each other before joining
        """
        super().__init__()
        self.allow_broadcast = allow_broadcast

    def forward(self, input_args: Union[List[torch.Tensor], torch.Tensor]
                ) -> torch.Tensor:
        """
        Returns a lone tensor untouched; otherwise concatenates the given
        tensors along dim -1
        """
        # A single tensor needs no joining
        if torch.is_tensor(input_args):
            return input_args
        tensors = input_args
        if self.allow_broadcast:
            # Common leading shape; -1 keeps each tensor's own last dimension
            leading = broadcast_shape(*(t.shape[:-1] for t in tensors))
            target = leading + (-1,)
            tensors = [t.expand(target) for t in tensors]
        return torch.cat(tensors, dim=-1)


class fcEncoderNet(nn.Module):
    """
    Standard fully-connected encoder NN for VAE.
    The encoder outputs mean and standard deviation of the encoded distribution.

    Args:
        in_dim:
            Shape of the input data: (h, w), (h, w, c), or (l,)
        latent_dim:
            Size of each of the two output heads
        num_classes:
            Added to the flattened input size (presumably the width of a
            class vector concatenated to the data -- confirm with callers)
        hidden_dim:
            Number of units in every hidden layer
        num_layers:
            Number of hidden layers
        activation:
            Name of the hidden-layer activation, passed to make_fc_layers
        softplus_out:
            If True, Softplus is applied to the second ("sigma") head;
            otherwise that head is returned as is
        flat:
            If True, the input is reshaped to (-1, flattened input size)
            before the dense layers
    """
    def __init__(self,
                 in_dim: Tuple[int, ...],
                 latent_dim: int = 2,
                 num_classes: int = 0,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 softplus_out: bool = True,
                 flat: bool = True
                 ) -> None:
        """
        Initializes module
        """
        super().__init__()
        if len(in_dim) not in [1, 2, 3]:
            raise ValueError("in_dim must be (h, w), (h, w, c), or (l,)")
        # Total number of input features after flattening
        self.in_dim = torch.prod(torch.tensor(in_dim)).item() + num_classes
        self.flat = flat
        self.concat = Concat()
        self.fc_layers = make_fc_layers(
            self.in_dim, hidden_dim, num_layers, activation)
        self.fc11 = nn.Linear(hidden_dim, latent_dim)
        self.fc12 = nn.Linear(hidden_dim, latent_dim)
        # nn.Identity (rather than a lambda) keeps the module picklable
        self.activation_out = nn.Softplus() if softplus_out else nn.Identity()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass

        Args:
            x: Input tensor, or a sequence of tensors that is first
               concatenated along the last dimension

        Returns:
            Tuple (mu, sigma) with the outputs of the two heads
        """
        x = self.concat(x)
        if self.flat:
            # reshape (unlike view) also accepts non-contiguous inputs
            x = x.reshape(-1, self.in_dim)
        x = self.fc_layers(x)
        mu = self.fc11(x)
        sigma = self.activation_out(self.fc12(x))
        return mu, sigma


class jfcEncoderNet(nn.Module):
    """
    Fully-connected encoder for joint VAE.
    The encoder outputs mean, standard deviation and class probabilities.

    Args:
        in_dim:
            Shape of the input data: (h, w), (h, w, c), or (l,)
        latent_dim:
            Size of the "mu" and "sigma" output heads
        discrete_dim:
            Size of the third head, whose output is passed through softmax
        hidden_dim:
            Number of units in every hidden layer
        num_layers:
            Number of hidden layers
        activation:
            Name of the hidden-layer activation, passed to make_fc_layers
        softplus_out:
            If True, Softplus is applied to the "sigma" head;
            otherwise that head is returned as is
        flat:
            If True, the input is reshaped to (-1, flattened input size)
            before the dense layers
    """
    def __init__(self,
                 in_dim: Tuple[int, ...],
                 latent_dim: int = 2,
                 discrete_dim: int = 0,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 softplus_out: bool = True,
                 flat: bool = True
                 ) -> None:
        """
        Initializes module
        """
        super().__init__()
        if len(in_dim) not in [1, 2, 3]:
            raise ValueError("in_dim must be (h, w), (h, w, c), or (l,)")
        # Total number of input features after flattening
        self.in_dim = torch.prod(torch.tensor(in_dim)).item()
        self.flat = flat
        self.concat = Concat()
        self.fc_layers = make_fc_layers(
            self.in_dim, hidden_dim, num_layers, activation)
        self.fc11 = nn.Linear(hidden_dim, latent_dim)
        self.fc12 = nn.Linear(hidden_dim, latent_dim)
        self.fc13 = nn.Linear(hidden_dim, discrete_dim)
        # nn.Identity (rather than a lambda) keeps the module picklable
        self.activation_out = nn.Softplus() if softplus_out else nn.Identity()

    def forward(self,
                x: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass

        Args:
            x: Input tensor, or a sequence of tensors that is first
               concatenated along the last dimension

        Returns:
            Tuple (mu, sigma, alpha); alpha is normalized with softmax
            over the last dimension
        """
        x = self.concat(x)
        if self.flat:
            # reshape (unlike view) also accepts non-contiguous inputs
            x = x.reshape(-1, self.in_dim)
        x = self.fc_layers(x)
        mu = self.fc11(x)
        sigma = self.activation_out(self.fc12(x))
        alpha = torch.softmax(self.fc13(x), dim=-1)
        return mu, sigma, alpha


class fcDecoderNet(nn.Module):
    """
    Standard fully-connected decoder for VAE

    Args:
        out_dim:
            Shape of the output data: (h, w), (h, w, c), or (l,)
        latent_dim:
            Size of the latent vector
        num_classes:
            Added to latent_dim to give the input size of the first layer
        hidden_dim:
            Number of units in every hidden layer
        num_layers:
            Number of hidden layers
        activation:
            Name of the hidden-layer activation, passed to make_fc_layers
        sigmoid_out:
            If True, Sigmoid is applied to the output layer
        unflat:
            If True, the output is reshaped to (-1, *out_dim);
            otherwise it is returned flat
    """
    def __init__(self,
                 out_dim: Tuple[int, ...],
                 latent_dim: int,
                 num_classes: int = 0,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 sigmoid_out: bool = True,
                 unflat: bool = True
                 ) -> None:
        """
        Initializes module
        """
        super().__init__()
        if len(out_dim) not in [1, 2, 3]:
            raise ValueError("out_dim must be (h, w), (h, w, c), or (l,)")
        self.unflat = unflat
        # Kept even when unflat is False so the attribute always exists
        self.reshape = out_dim
        # Number of output features before un-flattening
        out_dim = torch.prod(torch.tensor(out_dim)).item()
        self.concat = Concat()
        self.fc_layers = make_fc_layers(
            latent_dim+num_classes, hidden_dim, num_layers, activation)
        self.out = nn.Linear(hidden_dim, out_dim)
        # nn.Identity (rather than a lambda) keeps the module picklable
        self.activation_out = nn.Sigmoid() if sigmoid_out else nn.Identity()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            z: Latent tensor, or a sequence of tensors that is first
               concatenated along the last dimension

        Returns:
            Decoded tensor, reshaped to (-1, *out_dim) when unflat is True
        """
        z = self.concat(z)
        x = self.fc_layers(z)
        x = self.activation_out(self.out(x))
        if self.unflat:
            return x.view(-1, *self.reshape)
        return x


class sDecoderNet(nn.Module):
    """
    Spatial generator (decoder) network with fully-connected layers

    Args:
        out_dim:
            Shape of the output data: (h, w), (h, w, c), or (l,)
        latent_dim:
            Size of the latent vector
        num_classes:
            Added to latent_dim to give the latent input size of the
            coordinate/latent mixing layer
        hidden_dim:
            Number of units in every hidden layer
        num_layers:
            Number of hidden layers
        activation:
            Name of the hidden-layer activation, passed to make_fc_layers
        sigmoid_out:
            If True, Sigmoid is applied to the output layer
        unflat:
            If True, the output is reshaped to (-1, *out_dim);
            otherwise it is returned with one row per coordinate
    """
    def __init__(self,
                 out_dim: Tuple[int, ...],
                 latent_dim: int,
                 num_classes: int = 0,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh',
                 sigmoid_out: bool = True,
                 unflat: bool = True
                 ) -> None:
        """
        Initializes module
        """
        super().__init__()
        if len(out_dim) not in [1, 2, 3]:
            raise ValueError("out_dim must be (h, w), (h, w, c), or (l,)")
        self.unflat = unflat
        # Kept even when unflat is False so the attribute always exists
        self.reshape = out_dim
        # 1D data uses a single coordinate per point, images use two
        coord_dim = 1 if len(out_dim) < 2 else 2
        # One output value per channel for every coordinate; (l,) and (h, w)
        # are single-channel, (h, w, c) carries c channels in the last axis
        num_channels = out_dim[-1] if len(out_dim) == 3 else 1
        self.concat = Concat()
        self.coord_latent = coord_latent(
            latent_dim+num_classes, hidden_dim, coord_dim)
        self.fc_layers = make_fc_layers(
            hidden_dim, hidden_dim, num_layers, activation)
        self.out = nn.Linear(hidden_dim, num_channels)
        # nn.Identity (rather than a lambda) keeps the module picklable
        self.activation_out = nn.Sigmoid() if sigmoid_out else nn.Identity()

    def forward(self, x_coord: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x_coord:
                Coordinate tensor forwarded to the coordinate/latent mixing
                layer (assumed to be (batch, n_points, coord_dim) with
                n_points equal to the number of output pixels -- confirm
                with callers)
            z:
                Latent tensor, or a sequence of tensors that is first
                concatenated along the last dimension

        Returns:
            Decoded tensor, reshaped to (-1, *out_dim) when unflat is True
        """
        z = self.concat(z)
        x = self.coord_latent(x_coord, z)
        x = self.fc_layers(x)
        x = self.activation_out(self.out(x))
        if self.unflat:
            return x.view(-1, *self.reshape)
        return x


class coord_latent(nn.Module):
    """
    The "spatial" part of the trVAE's decoder that allows for translational
    and rotational invariance (based on https://arxiv.org/abs/1909.11663)

    Args:
        latent_dim:
            Size of the latent vector
        out_dim:
            Number of output features of both linear maps
        ndim:
            Number of coordinates per point
        activation_out:
            If True, Tanh is applied to the summed features
    """
    def __init__(self,
                 latent_dim: int,
                 out_dim: int,
                 ndim: int = 2,
                 activation_out: bool = True) -> None:
        """
        Initializes module
        """
        super().__init__()
        self.fc_coord = nn.Linear(ndim, out_dim)
        # No bias here: the coordinate branch already contributes one
        self.fc_latent = nn.Linear(latent_dim, out_dim, bias=False)
        self.activation = nn.Tanh() if activation_out else None

    def forward(self,
                x_coord: torch.Tensor,
                z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Args:
            x_coord:
                Coordinates; the first two dimensions are read as
                (batch, number of points)
            z:
                Latent tensor whose last dimension is latent_dim

        Returns:
            Tensor with (batch * number of points) rows and out_dim columns
        """
        batch_dim, n = x_coord.size()[:2]
        # Project every coordinate independently of its batch entry
        x_coord = x_coord.reshape(batch_dim * n, -1)
        h_x = self.fc_coord(x_coord)
        h_x = h_x.reshape(batch_dim, n, -1)
        h_z = self.fc_latent(z)
        h_z = h_z.reshape(-1, h_z.size(-1))
        # Broadcast each latent projection over all points of its sample
        h = h_x.add(h_z.unsqueeze(1))
        h = h.reshape(batch_dim * n, -1)
        if self.activation is not None:
            h = self.activation(h)
        return h


class fcClassifierNet(nn.Module):
    """
    Simple classification neural network with fully-connected layers only.

    Args:
        in_dim:
            Shape of the input data: (h, w), (h, w, c), or (l,)
        num_classes:
            Number of output classes
        hidden_dim:
            Number of units in every hidden layer
        num_layers:
            Number of hidden layers
        activation:
            Name of the hidden-layer activation, passed to make_fc_layers
    """
    def __init__(self,
                 in_dim: Tuple[int, ...],
                 num_classes: int,
                 hidden_dim: int = 128,
                 num_layers: int = 2,
                 activation: str = 'tanh'
                 ) -> None:
        """
        Initializes module
        """
        super().__init__()
        if len(in_dim) not in [1, 2, 3]:
            raise ValueError("in_dim must be (h, w), (h, w, c), or (l,)")
        # Total number of input features after flattening
        self.in_dim = torch.prod(torch.tensor(in_dim)).item()
        self.fc_layers = make_fc_layers(
            self.in_dim, hidden_dim, num_layers, activation)
        self.out = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        The input is not flattened here, so its last dimension must
        already equal the product of in_dim.

        Returns:
            Class probabilities (softmax over the last dimension)
        """
        x = self.fc_layers(x)
        x = self.out(x)
        return torch.softmax(x, dim=-1)


def make_fc_layers(in_dim: int,
                   hidden_dim: int = 128,
                   num_layers: int = 2,
                   activation: str = "tanh"
                   ) -> nn.Sequential:
    """
    Generates a module with stacked fully-connected (aka dense) layers

    Args:
        in_dim:
            Number of input features of the first layer
        hidden_dim:
            Number of output features of every layer
        num_layers:
            Number of (linear, activation) pairs
        activation:
            Name passed to get_activation to obtain the activation class

    Returns:
        nn.Sequential of num_layers linear layers, each followed by a
        newly created activation module
    """
    layers = []
    for i in range(num_layers):
        # Only the first layer sees the raw input width
        layer_in = in_dim if i == 0 else hidden_dim
        layers.append(nn.Linear(layer_in, hidden_dim))
        layers.append(get_activation(activation)())
    return nn.Sequential(*layers)