Helper Module for Deep Learning.
Module containing VAE utilities.
Code: https://github.com/YannDubs/disentangling-vae
-
pynet.models.vae.utils.add_labels(input_image, labels)
Adds labels next to the rows of an image.
- Parameters
input_image: PIL.Image
the image to which to add the labels.
labels: list
the list of labels to plot.
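A minimal sketch of how labels could be drawn next to the rows of an image with Pillow. It assumes that the labels go in a white margin to the right of the image and that the image consists of `len(labels)` rows of equal height. `add_labels_sketch` and its `margin` parameter are hypothetical and are not pynet's implementation.

```python
from PIL import Image, ImageDraw


def add_labels_sketch(input_image, labels, margin=100):
    """Hypothetical helper: write one label per row in a right-hand margin."""
    if not labels:
        return input_image.copy()
    width, height = input_image.size
    # Assumption: labels are placed in a white margin to the right of the image.
    new_img = Image.new("RGB", (width + margin, height), color="white")
    new_img.paste(input_image, (0, 0))
    draw = ImageDraw.Draw(new_img)
    # Assumption: the image is made of len(labels) rows of equal height.
    row_height = height / len(labels)
    for i, label in enumerate(labels):
        y = int((i + 0.5) * row_height)  # vertical centre of row i
        draw.text((width + 5, y), label, fill=(0, 0, 0))
    return new_img
```

For example, `add_labels_sketch(grid_img, ["z0", "z1", "z2"])` would annotate a three-row traversal figure, one label per latent dimension.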
-
pynet.models.vae.utils.get_traversal_range(mean=0, std=1, max_traversal=0.475)
Returns the traversal range, in absolute terms, that corresponds to max_traversal.
- Parameters
mean: float, default 0
normal distribution mean.
std: float, default 1
normal distribution sigma.
max_traversal: float, default 0.475
the maximum displacement induced by a latent traversal; traversals are assumed to be symmetrical. If max_traversal >= 0.5, it is used directly as an absolute value; if max_traversal < 0.5, it is interpreted as a fraction of the distribution (a quantile). For example, for the prior, which is a standard normal, max_traversal = 0.45 corresponds to an absolute value of 1.645, because 2 × 0.45 = 90% of a standard normal lies between -1.645 and 1.645. Note that for the posterior the distribution is no longer standard normal.
- Returns
out: tuple of 2 floats
the two bounds of the traversal range.
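The two modes of max_traversal can be sketched with the standard library's `statistics.NormalDist`. This is a minimal sketch of the documented behaviour, not pynet's code: `get_traversal_range_sketch` is a hypothetical name, and it assumes that in quantile mode the interval is centred on `mean`.

```python
from statistics import NormalDist


def get_traversal_range_sketch(mean=0.0, std=1.0, max_traversal=0.475):
    """Return (low, high) bounds of a symmetrical latent traversal."""
    if max_traversal >= 0.5:
        # Absolute mode: max_traversal is itself the displacement.
        return (-max_traversal, max_traversal)
    # Quantile mode: keep the central 2 * max_traversal of N(mean, std**2).
    # Assumption: the interval is centred on mean.
    dist = NormalDist(mu=mean, sigma=std)
    return (dist.inv_cdf(0.5 - max_traversal), dist.inv_cdf(0.5 + max_traversal))
```

With the default max_traversal = 0.475 and a standard normal, the central 95% of the mass gives bounds of roughly ±1.96; max_traversal = 0.45 gives roughly ±1.645, matching the example above.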
-
pynet.models.vae.utils.make_mosaic_img(arr)
Converts a grid of image arrays into a single mosaic.
- Parameters
arr: numpy.ndarray (ROWS, COLS, C, H, W)
a grid of images, all of the same size, from which to build the mosaic.
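Tiling a (ROWS, COLS, C, H, W) grid amounts to a transpose followed by a reshape. The sketch below is an assumption about how this can be done with NumPy, not pynet's implementation; `make_mosaic_sketch` is hypothetical, and the channels-last (ROWS*H, COLS*W, C) output layout is chosen here for illustration.

```python
import numpy as np


def make_mosaic_sketch(arr):
    """Tile a (ROWS, COLS, C, H, W) grid into one (ROWS*H, COLS*W, C) array."""
    rows, cols, channels, height, width = arr.shape
    # Reorder to (ROWS, H, COLS, W, C) so that a reshape places tiles side by side.
    tiled = arr.transpose(0, 3, 1, 4, 2)
    return tiled.reshape(rows * height, cols * width, channels)
```

A channels-last layout is convenient because, for `uint8` data, it can be passed to `PIL.Image.fromarray` after dropping a singleton channel axis.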
-
pynet.models.vae.utils.reconstruct_traverse(model, data, n_per_latent=8, n_latents=None, is_posterior=False, filename=None)
Creates a figure whose first row shows the original images, whose second row shows their reconstructions, and whose remaining rows show traversals (prior or posterior) of the latent dimensions.
- Parameters
model: nn.Module
the trained network.
data: torch.Tensor (N, C, H, W)
data to be reconstructed.
n_per_latent: int, default 8
the number of points to include in the traversal of a latent dimension, i.e. the number of columns.
n_latents: int, default None
the number of latent dimensions to display, i.e. the number of rows. If None, all latent dimensions are displayed.
is_posterior: bool, default False
whether to sample from the posterior.
filename: str, default None
path to save the final image.
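The row layout of this figure can be illustrated by stacking arrays. In this sketch, `stack_figure_rows` is a hypothetical helper, and it assumes that the originals row and the reconstructions row each hold n_per_latent images, so that every row has the same number of columns; it only assembles already-computed images and does not call a model.

```python
import numpy as np


def stack_figure_rows(originals, reconstructions, traversal_rows):
    """Stack figure rows: originals, reconstructions, then one row per latent.

    originals, reconstructions: (n_per_latent, C, H, W)
    traversal_rows: (n_latents, n_per_latent, C, H, W)
    Returns an array of shape (2 + n_latents, n_per_latent, C, H, W).
    """
    if originals.shape != reconstructions.shape:
        raise ValueError("originals and reconstructions must have the same shape")
    if traversal_rows.shape[1:] != originals.shape:
        raise ValueError("each traversal row must match the shape of the originals row")
    # Add a leading row axis to the first two rows, then concatenate along it.
    return np.concatenate([originals[None], reconstructions[None], traversal_rows], axis=0)
```

The result has the (ROWS, COLS, C, H, W) layout that make_mosaic_img takes as input.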
-
pynet.models.vae.utils.traversals(model, device, data=None, n_per_latent=8, n_latents=None)
Traverses all latent dimensions (prior or posterior) one by one and plots a grid of images in which each row corresponds to the traversal of one latent dimension.
- Parameters
model: nn.Module
the trained network.
device: torch.device
the device on which to run the model.
data: torch.Tensor (N, C, H, W), default None
data to use for computing the posterior. If None, the mean of the prior (all zeros) is used for all non-traversed dimensions.
n_per_latent: int, default 8
the number of points to include in the traversal of a latent dimension, i.e. the number of columns.
n_latents: int, default None
the number of latent dimensions to display, i.e. the number of rows. If None, all latent dimensions are displayed.
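For the prior case (data is None), the grid can be sketched as one decoded batch per latent dimension, with all other dimensions held at zero. This is a hypothetical sketch, not pynet's code: `prior_traversal_grid` is an illustrative name, and `decode` stands in for a model's decoder and is assumed to map an (n, latent_dim) array to (n, C, H, W) images.

```python
from statistics import NormalDist

import numpy as np


def prior_traversal_grid(decode, latent_dim, n_per_latent=8, n_latents=None,
                         max_traversal=0.475):
    """Decode one traversal per latent dimension, others fixed at the prior mean."""
    n_latents = latent_dim if n_latents is None else n_latents
    # Symmetrical range: absolute if max_traversal >= 0.5, else a standard-normal quantile.
    if max_traversal >= 0.5:
        bound = max_traversal
    else:
        bound = NormalDist().inv_cdf(0.5 + max_traversal)
    steps = np.linspace(-bound, bound, n_per_latent)
    rows = []
    for idx in range(n_latents):
        samples = np.zeros((n_per_latent, latent_dim))  # prior mean: all zeros
        samples[:, idx] = steps  # traverse dimension idx only
        rows.append(decode(samples))  # assumed to return (n_per_latent, C, H, W)
    return np.stack(rows)  # (n_latents, n_per_latent, C, H, W)
```

Each row of the returned grid holds the n_per_latent images of one traversal, which matches the "one row per latent dimension" layout described above.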
-
pynet.models.vae.utils.traverse_line(model, idx, n_samples, data=None, max_traversal=0.475)
Returns latent samples corresponding to a traversal of the latent dimension indicated by idx.
- Parameters
model: nn.Module
the trained network.
idx: int
index of the continuous dimension to traverse. If the continuous latent vector is 10-dimensional and idx = 7, the dimension at index 7 is traversed while all others are held fixed.
n_samples: int
number of samples to generate.
data: torch.Tensor (N, C, H, W), default None
data to use for computing the posterior. If None, the mean of the prior (all zeros) is used for all non-traversed dimensions.
max_traversal: float, default 0.475
the maximum displacement induced by a latent traversal; traversals are assumed to be symmetrical. If max_traversal >= 0.5, it is used directly as an absolute value; if max_traversal < 0.5, it is interpreted as a fraction of the distribution (a quantile). For example, for the prior, which is a standard normal, max_traversal = 0.45 corresponds to an absolute value of 1.645, because 2 × 0.45 = 90% of a standard normal lies between -1.645 and 1.645. Note that for the posterior the distribution is no longer standard normal.
- Returns
samples: torch.Tensor (n_samples, latent_size)
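The construction of the returned samples can be sketched with NumPy instead of PyTorch. In this hypothetical sketch, the encoder step is replaced by explicit arguments: `base` stands in for a latent vector computed from data (for example a posterior sample), and `mean` and `std` stand in for the posterior statistics of dimension idx. `traverse_line_sketch` is an illustrative name, not pynet's API.

```python
from statistics import NormalDist

import numpy as np


def traverse_line_sketch(latent_dim, idx, n_samples, base=None, mean=0.0,
                         std=1.0, max_traversal=0.475):
    """Return (n_samples, latent_dim) latents that vary only along dimension idx."""
    if base is None:
        samples = np.zeros((n_samples, latent_dim))  # prior mean for all dims
    else:
        # Assumption: base is one latent vector, e.g. a posterior sample of the input.
        samples = np.tile(np.asarray(base, dtype=float), (n_samples, 1))
    if max_traversal >= 0.5:
        low, high = -max_traversal, max_traversal  # absolute displacement
    else:
        # Quantile of N(mean, std**2); mean and std are hypothetical stand-ins
        # for the posterior statistics of dimension idx.
        dist = NormalDist(mu=mean, sigma=std)
        low, high = dist.inv_cdf(0.5 - max_traversal), dist.inv_cdf(0.5 + max_traversal)
    samples[:, idx] = np.linspace(low, high, n_samples)
    return samples
```

With base=None and the default arguments, only column idx varies (from about -1.96 to 1.96) and every other column stays at zero, which corresponds to the prior case described for data.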