from typing import Union, Tuple
import torch
import torch.tensor as tt
import pyro.distributions as dist
def generate_grid(data_dim: Tuple[int, ...]) -> torch.Tensor:
    """
    Generates 1D or 2D grid of coordinates

    Args:
        data_dim: Spatial dimensions of the data, either ``(n,)`` or ``(h, w)``.

    Returns:
        For 1D input, an ``(n, 1)`` tensor of points evenly spaced on [-1, 1].
        For 2D input, the ``(h * w, 2)`` coordinate tensor built by
        ``imcoordgrid``.

    Raises:
        NotImplementedError: if ``data_dim`` has neither 1 nor 2 entries.
    """
    ndim = len(data_dim)
    if ndim not in (1, 2):
        raise NotImplementedError("Currently supports only 1D and 2D data")
    if ndim == 1:
        # Trailing singleton axis so 1D coordinates are a column of points.
        return torch.linspace(-1, 1, data_dim[0])[:, None]
    return imcoordgrid(data_dim)
def grid2xy(X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
    """
    Flattens a pair of 2D coordinate grids into a list of (x1, x2) points.

    Args:
        X1: First coordinate grid, shape ``(h, w)``.
        X2: Second coordinate grid, same shape as ``X1``.

    Returns:
        Tensor of shape ``(h * w, 2)``; row ``k`` pairs the ``k``-th
        row-major element of ``X1`` with that of ``X2``.
    """
    stacked = torch.stack((X1, X2), dim=0)  # (2, h, w)
    n_channels = stacked.shape[0]
    n_points = stacked.shape[1] * stacked.shape[2]
    return stacked.reshape(n_channels, n_points).t()
def imcoordgrid(im_dim: Tuple[int, ...]) -> torch.Tensor:
    """
    Generates a 2D grid of normalized image coordinates.

    Args:
        im_dim: Image dimensions ``(h, w)``; only the first two entries
            are read.

    Returns:
        ``(h * w, 2)`` tensor of points in row-major order. Column 0 runs
        from -1 to 1 along the first axis; column 1 runs from 1 down to -1
        along the second axis.
    """
    xx = torch.linspace(-1, 1, im_dim[0])
    yy = torch.linspace(1, -1, im_dim[1])
    # Same grids as torch.meshgrid(xx, yy, indexing="ij"), built by
    # broadcasting: avoids the missing-`indexing` UserWarning on newer torch
    # without needing the `indexing` keyword, which older torch lacks.
    x0 = xx[:, None].expand(len(xx), len(yy))
    x1 = yy[None, :].expand(len(xx), len(yy))
    return grid2xy(x0, x1)
def generate_latent_grid(
    d: int, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Generates a grid of latent space coordinates

    Args:
        d: Number of grid points along each of the two latent axes.
        **kwargs:
            z_coord: Optional sequence of four values ``(z1, z2, z3, z4)``.
                When given, the first axis spans ``z2 -> z1`` and the second
                spans ``z3 -> z4`` linearly. Otherwise both axes are
                standard-normal quantiles of evenly spaced probabilities in
                [0.05, 0.95] (first axis descending, second ascending).

    Returns:
        Tuple ``(z, (grid_x, grid_y))``. ``z`` is a float32 tensor of shape
        ``(d * d, 2)`` listing every ``(x, y)`` pair with ``x`` varying
        slowest; ``grid_x`` and ``grid_y`` are the 1D axis values.
    """
    z_coord = kwargs.get("z_coord")
    # Explicit None/length check instead of plain truthiness: numpy arrays and
    # tensors with more than one element raise on `bool(...)`.
    if z_coord is not None and len(z_coord) > 0:
        z1, z2, z3, z4 = z_coord
        grid_x = torch.linspace(z2, z1, d)
        grid_y = torch.linspace(z3, z4, d)
    else:
        grid_x = dist.Normal(0, 1).icdf(torch.linspace(0.95, 0.05, d))
        grid_y = dist.Normal(0, 1).icdf(torch.linspace(0.05, 0.95, d))
    # Same row order as a nested loop with grid_x outer and grid_y inner,
    # but vectorized rather than assembling d*d one-row tensors.
    z = torch.cartesian_prod(grid_x, grid_y).float()
    return z, (grid_x, grid_y)