Source code for pyroved.models.base

"""
base.py
=========

Variational encoder-decoder base class

Created by Maxim Ziatdinov (email: ziatdinovmax@gmail.com)
"""

import os
from abc import abstractmethod
from typing import List, Tuple, Type, Union

import torch
import torch.nn as nn

from ..utils import init_dataloader


class baseVAE(nn.Module):
    """
    Base class for variational autoencoder models

    Holds the compute device, user-settable encoder/decoder networks and
    batch-wise encode/decode helpers. Concrete models must implement
    Pyro's ``model`` and ``guide``.
    """
    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # inference network; set via set_encoder() (or presumably by a subclass)
        self.encoder_z = None
        # generative network; set via set_decoder() (or presumably by a subclass)
        self.decoder = None
        # when > 0, _decode prepends self.grid to every latent batch
        self.coord = 0
        # grid fed to the decoder when coord > 0; not assigned in this class,
        # presumably set by subclasses — confirm in concrete models
        self.grid = None

    @abstractmethod
    def model(self):
        """Pyro's model"""
        raise NotImplementedError

    @abstractmethod
    def guide(self):
        """Pyro's guide"""
        raise NotImplementedError

    def _encode(self,
                *input_args: Union[torch.Tensor, List[torch.Tensor]],
                **kwargs: int) -> torch.Tensor:
        """
        Encodes data using a trained inference (encoder) network
        in a batch-by-batch fashion

        Args:
            *input_args: Tensors passed to ``init_dataloader``; within each
                batch they are concatenated along the last dimension before
                being fed to the encoder
            **kwargs: Forwarded to ``init_dataloader``

        Returns:
            Encoder outputs for all batches concatenated along the first
            dimension, on CPU. When the encoder returns several tensors,
            they are concatenated along the last dimension first.
        """
        def inference(x: List[torch.Tensor]) -> torch.Tensor:
            x = torch.cat(x, -1).to(self.device)
            with torch.no_grad():
                encoded = self.encoder_z(x)
            # Encoders may return a tuple of tensors; a single tensor is
            # kept as is instead of failing inside torch.cat
            if not isinstance(encoded, torch.Tensor):
                encoded = torch.cat(encoded, -1)
            return encoded.cpu()

        loader = init_dataloader(*input_args, shuffle=False, **kwargs)
        return torch.cat([inference(x) for x in loader])

    def _decode(self, z_new: torch.Tensor, **kwargs: int) -> torch.Tensor:
        """
        Decodes latent coordinates in a batch-by-batch fashion

        Args:
            z_new: Latent coordinates to decode
            **kwargs: Forwarded to ``init_dataloader``

        Returns:
            Decoder outputs for all batches concatenated along the first
            dimension, on CPU
        """
        def generator(z: List[torch.Tensor]) -> torch.Tensor:
            with torch.no_grad():
                loc = self.decoder(*z)
            return loc.cpu()

        loader = init_dataloader(z_new, shuffle=False, **kwargs)
        # Move the grid to the decoder's device once, outside the batch loop
        grid = self.grid.to(self.device) if self.coord > 0 else None
        x_decoded = []
        for z in loader:
            # Batches come straight from the (possibly CPU) input tensor;
            # put them on the decoder's device, as _encode does for its input
            z = [z_.to(self.device) for z_ in z]
            if grid is not None:
                z = [grid.expand(z[0].shape[0], *grid.shape)] + z
            x_decoded.append(generator(z))
        return torch.cat(x_decoded)

    def set_encoder(self, encoder_net: torch.nn.Module) -> None:
        """
        Sets a user-defined encoder neural network

        The network is moved to ``self.device``.
        """
        self.encoder_z = encoder_net.to(self.device)

    def set_decoder(self, decoder_net: torch.nn.Module) -> None:
        """
        Sets a user-defined decoder neural network

        The network is moved to ``self.device``.
        """
        self.decoder = decoder_net.to(self.device)

    def save_weights(self, filepath: str) -> None:
        """
        Saves trained weights of encoder(s) and decoder

        Args:
            filepath: Output path without extension; ``'.pt'`` is appended
        """
        torch.save(self.state_dict(), filepath + '.pt')

    def load_weights(self, filepath: str) -> None:
        """
        Loads saved weights of encoder(s) and decoder

        Args:
            filepath: Path of a file written by ``save_weights``. Either the
                full file name (``'model.pt'``) or the same prefix that was
                passed to ``save_weights`` (``'model'``) is accepted.

        Note:
            ``torch.load`` unpickles the file; only load trusted files.
        """
        if (isinstance(filepath, str) and not os.path.isfile(filepath)
                and os.path.isfile(filepath + '.pt')):
            # save_weights appends '.pt'; accept the same prefix here
            filepath = filepath + '.pt'
        weights = torch.load(filepath, map_location=self.device)
        self.load_state_dict(weights)