
Helper Module for Deep Learning.

Module containing VAE utilities.

Code: https://github.com/YannDubs/disentangling-vae

pynet.models.vae.utils.add_labels(input_image, labels)

Adds labels next to rows of an image.

Parameters

input_image: PIL.Image

the image to which to add the labels.

labels: list

the list of labels to plot.
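A minimal sketch of the labeling idea with Pillow follows. It assumes the labels are drawn in a blank strip appended to the right of the image, with one label per row and all rows of equal height, and that a new image is returned. The function name `add_labels_sketch` and its `label_width`, `bg_color` and `text_color` parameters are hypothetical; the actual pynet implementation may size and position the labels differently.

```python
from PIL import Image, ImageDraw


def add_labels_sketch(input_image, labels, label_width=100,
                      bg_color="white", text_color="black"):
    """Sketch: return a copy of `input_image` widened on the right, one label per row.

    Assumes len(labels) equal-height rows; all parameter names beyond
    input_image/labels are hypothetical.
    """
    width, height = input_image.size
    # New canvas: original image on the left, blank label strip on the right.
    new_img = Image.new(input_image.mode, (width + label_width, height), color=bg_color)
    new_img.paste(input_image, (0, 0))
    draw = ImageDraw.Draw(new_img)
    row_height = height / len(labels)
    for i, label in enumerate(labels):
        # Put the label's top-left corner at the vertical middle of row i.
        y = int((i + 0.5) * row_height)
        draw.text((width + 5, y), str(label), fill=text_color)
    return new_img
```

For example, `add_labels_sketch(Image.new("RGB", (64, 32)), ["z0", "z1"])` returns a 164 x 32 image with "z0" and "z1" written beside the two rows.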

pynet.models.vae.utils.get_traversal_range(mean=0, std=1, max_traversal=0.475)

Return the corresponding traversal range in absolute terms.

Parameters

mean: float, default 0

normal distribution mean.

std: float, default 1

normal distribution standard deviation (sigma).

max_traversal: float, default 0.475

the maximum displacement induced by a latent traversal. Symmetrical traversals are assumed. If max_traversal >= 0.5, it is used as an absolute bound on the traversal; if max_traversal < 0.5, it is interpreted as a fraction of the distribution's probability mass (a quantile). For example, the prior is a standard normal, so max_traversal = 0.45 corresponds to an absolute value of 1.645, because 2 * 0.45 = 90% of a standard normal lies between -1.645 and 1.645. Note that in the case of the posterior, the distribution is no longer standard normal.

Returns

out: 2-tuple

traversal range.
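The two interpretations of max_traversal described above can be sketched with the standard library's `statistics.NormalDist`. This is a minimal sketch, not the library's code: the name `traversal_range` is hypothetical, and centering the quantile range on `mean` (while using ±max_traversal as-is in the absolute case) is an assumption about how mean and std are applied.

```python
from statistics import NormalDist


def traversal_range(mean=0.0, std=1.0, max_traversal=0.475):
    """Sketch: symmetric (lower, upper) traversal bounds.

    max_traversal >= 0.5 is taken as an absolute bound; max_traversal < 0.5 is
    taken as a probability mass, so the range covers the central
    2 * max_traversal of a Normal(mean, std) distribution.
    """
    if max_traversal >= 0.5:
        # Absolute traversal: the value itself is the bound.
        return (-max_traversal, max_traversal)
    # Quantile traversal: e.g. 0.45 -> central 90% -> 5% and 95% quantiles.
    dist = NormalDist(mu=mean, sigma=std)
    return (dist.inv_cdf(0.5 - max_traversal), dist.inv_cdf(0.5 + max_traversal))


print(traversal_range(max_traversal=0.45))  # approximately (-1.645, 1.645)
print(traversal_range(max_traversal=2.0))   # (-2.0, 2.0)
```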

pynet.models.vae.utils.make_mosaic_img(arr)

Converts a grid of image arrays into a single mosaic image.

Parameters

arr: numpy.ndarray (ROWS, COLS, C, H, W)

the images, all of the same size and organized as a grid, from which to build the mosaic.
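Tiling a (ROWS, COLS, C, H, W) grid into one image is a transpose followed by a reshape in NumPy. The sketch below assumes a channels-last output of shape (ROWS*H, COLS*W, C), which is convenient for PIL; the output layout of make_mosaic_img itself is not documented here, and the name `mosaic_from_grid` is hypothetical.

```python
import numpy as np


def mosaic_from_grid(arr):
    """Sketch: tile a (ROWS, COLS, C, H, W) grid into one (ROWS*H, COLS*W, C) image."""
    rows, cols, channels, height, width = arr.shape
    # Reorder to (ROWS, H, COLS, W, C) so rows/height and cols/width are adjacent,
    # then merge each adjacent pair into a single image axis.
    tiled = arr.transpose(0, 3, 1, 4, 2)
    return tiled.reshape(rows * height, cols * width, channels)


print(mosaic_from_grid(np.zeros((3, 8, 1, 32, 32))).shape)  # (96, 256, 1)
```

Pixel (h, w) of tile (r, c) ends up at mosaic position (r*H + h, c*W + w), because the reshape flattens each adjacent (block, within-block) axis pair in row-major order.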

pynet.models.vae.utils.reconstruct_traverse(model, data, n_per_latent=8, n_latents=None, is_posterior=False, filename=None)

Creates a figure whose first row shows the original images, whose second row shows their reconstructions, and whose remaining rows show traversals (prior or posterior) of the latent dimensions.

Parameters

model: nn.Module

the trained network.

data: torch.Tensor (N, C, H, W)

data to be reconstructed.

n_per_latent: int, default 8

the number of points to include in the traversal of a latent dimension, i.e. the number of columns.

n_latents: int, default None

the number of latent dimensions to display, i.e. the number of rows. If None, all latents are used.

is_posterior: bool, default False

whether to sample from the posterior.

filename: str, default None

path to save the final image.
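The row layout described above (originals, then reconstructions, then one traversal per latent dimension) can be sketched as a grid-assembly step in NumPy. This assumes the originals, reconstructions and decoded traversals are already NumPy arrays and that the first n_per_latent images fill the first two rows; the helper name `assemble_reconstruct_traverse_grid` is hypothetical. The resulting (ROWS, COLS, C, H, W) grid has the shape make_mosaic_img expects.

```python
import numpy as np


def assemble_reconstruct_traverse_grid(originals, reconstructions, traversal_rows):
    """Sketch: stack the rows of a reconstruct/traverse figure.

    originals, reconstructions: (N, C, H, W) arrays with N >= n_per_latent.
    traversal_rows: (n_latents, n_per_latent, C, H, W) decoded traversals.
    Returns a (2 + n_latents, n_per_latent, C, H, W) grid.
    """
    n_per_latent = traversal_rows.shape[1]
    # Row 0: original images; row 1: their reconstructions.
    top = np.stack([originals[:n_per_latent], reconstructions[:n_per_latent]])
    # Remaining rows: one latent traversal per row.
    return np.concatenate([top, traversal_rows], axis=0)
```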

pynet.models.vae.utils.traversals(model, device, data=None, n_per_latent=8, n_latents=None)

Traverses all latent dimensions (prior or posterior) one by one and plots a grid of images in which each row corresponds to the traversal of one latent dimension.

Parameters

model: nn.Module

the trained network.

device: torch.device

the device on which to run the model.

data: torch.Tensor (N, C, H, W), default None

data used to compute the posterior. If None, the mean of the prior (all zeros) is used for all other dimensions.

n_per_latent: int, default 8

the number of points to include in the traversal of a latent dimension, i.e. the number of columns.

n_latents: int, default None

the number of latent dimensions to display, i.e. the number of rows. If None, all latents are used.
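The one-row-per-latent loop described above can be sketched for the prior case (data=None) in NumPy. In place of the trained torch model and device, this sketch takes a hypothetical `decode` callable that maps an (n, latent_size) array of latent codes to (n, C, H, W) images; the function name and its parameters are illustrative only, and the posterior case is not covered.

```python
import numpy as np
from statistics import NormalDist


def prior_traversal_grid(decode, latent_size, n_per_latent=8, n_latents=None,
                         max_traversal=0.475):
    """Sketch: decode one prior traversal per latent dimension.

    decode: hypothetical callable mapping an (n, latent_size) array of latent
    codes to an (n, C, H, W) array of images.
    Returns an (n_latents, n_per_latent, C, H, W) grid with one latent per row.
    """
    n_latents = latent_size if n_latents is None else n_latents
    if max_traversal >= 0.5:
        bound = max_traversal  # absolute bound
    else:
        # Standard-normal prior: bound covering the central 2 * max_traversal mass.
        bound = NormalDist().inv_cdf(0.5 + max_traversal)
    values = np.linspace(-bound, bound, n_per_latent)
    rows = []
    for idx in range(n_latents):
        z = np.zeros((n_per_latent, latent_size))  # prior mean for all other dims
        z[:, idx] = values  # traverse only dimension idx
        rows.append(decode(z))
    return np.stack(rows)
```

Holding every other dimension at the prior mean isolates the effect of a single latent, which is what makes each row interpretable.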

pynet.models.vae.utils.traverse_line(model, idx, n_samples, data=None, max_traversal=0.475)

Return latent samples corresponding to a traversal of a latent variable indicated by idx.

Parameters

model: nn.Module

the trained network.

idx: int

index of the continuous latent dimension to traverse. If the continuous latent vector is 10-dimensional and idx = 7, the dimension at zero-based index 7 (i.e. the 8th dimension) is traversed while all others are held fixed.

n_samples: int

number of samples to generate.

data: torch.Tensor (N, C, H, W), default None

data used to compute the posterior. If None, the mean of the prior (all zeros) is used for all other dimensions.

max_traversal: float, default 0.475

the maximum displacement induced by a latent traversal. Symmetrical traversals are assumed. If max_traversal >= 0.5, it is used as an absolute bound on the traversal; if max_traversal < 0.5, it is interpreted as a fraction of the distribution's probability mass (a quantile). For example, the prior is a standard normal, so max_traversal = 0.45 corresponds to an absolute value of 1.645, because 2 * 0.45 = 90% of a standard normal lies between -1.645 and 1.645. Note that in the case of the posterior, the distribution is no longer standard normal.

Returns

samples: torch.Tensor (n_samples, latent_size)
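A minimal NumPy sketch of the traversal logic described above follows. Instead of a model and input data, it takes hypothetical `post_mean` and `post_std` vectors for a single encoded image. In the posterior case it fixes the other dimensions at the posterior mean, which is an assumption (the library may instead draw a sample from the posterior). The function name `traverse_line_sketch` is hypothetical.

```python
import numpy as np
from statistics import NormalDist


def traverse_line_sketch(latent_size, idx, n_samples, post_mean=None, post_std=None,
                         max_traversal=0.475):
    """Sketch: (n_samples, latent_size) latent codes traversing dimension idx.

    post_mean / post_std: hypothetical (latent_size,) posterior parameters of a
    single encoded image; if None, the standard-normal prior is used.
    """
    if post_mean is None:
        base = np.zeros(latent_size)  # prior mean for every other dimension
        mean, std = 0.0, 1.0
    else:
        base = np.asarray(post_mean, dtype=float)  # other dims fixed at posterior mean
        mean, std = float(base[idx]), float(post_std[idx])
    if max_traversal >= 0.5:
        low, high = -max_traversal, max_traversal  # absolute bounds
    else:
        dist = NormalDist(mu=mean, sigma=std)  # quantile bounds around `mean`
        low, high = dist.inv_cdf(0.5 - max_traversal), dist.inv_cdf(0.5 + max_traversal)
    samples = np.tile(base, (n_samples, 1))  # one copy of the base code per sample
    samples[:, idx] = np.linspace(low, high, n_samples)
    return samples
```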


© 2019, pynet developers.
Inspired by AZMIND template.