Helper Module for Deep Learning.
Gaussian Mixture Variational Auto-Encoder (GMVAE).
Two implementations are proposed:
VAEGMP is an adaptation of the VAE that uses a Gaussian Mixture prior instead of a standard Normal distribution.
GMVAE is an attempt to replicate the work described in [1] and [2].
[1] Gaussian Mixture VAE: Lessons in Variational Inference, Generative Models, and Deep Nets: http://ruishu.io/2016/12/25/gmvae
[2] Deep Unsupervised Clustering with Gaussian Mixture Variational Autoencoders, Nat Dilokthanakul, arXiv 2017.
Code: https://github.com/jariasf/GMVAE, https://github.com/mazrk7/gmvae
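The two priors can be compared directly. The sketch below evaluates the log-density of one latent point under a standard Normal prior and under a Gaussian mixture prior with diagonal covariances. It is plain Python, not the pynet API, and the mixture weights, means and standard deviations are made-up values.

```python
import math

def log_normal_diag(z, mean, std):
    """Log-density of a diagonal Gaussian N(mean, diag(std**2)) evaluated at z."""
    return sum(
        -0.5 * math.log(2.0 * math.pi) - math.log(s) - 0.5 * ((zi - m) / s) ** 2
        for zi, m, s in zip(z, mean, std)
    )

def log_gmm_diag(z, weights, means, stds):
    """Log-density of a mixture of diagonal Gaussians, using log-sum-exp for stability."""
    terms = [math.log(w) + log_normal_diag(z, m, s)
             for w, m, s in zip(weights, means, stds)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))

# Standard VAE prior: N(0, I) in a 2-D latent space.
z = [2.0, 2.0]
log_p_standard = log_normal_diag(z, [0.0, 0.0], [1.0, 1.0])

# Made-up 2-component mixture prior; one component is centred on z.
weights = [0.5, 0.5]
means = [[-2.0, -2.0], [2.0, 2.0]]
stds = [[1.0, 1.0], [1.0, 1.0]]
log_p_mixture = log_gmm_diag(z, weights, means, stds)
```

A point sitting on a mixture component gets a much higher prior density than it does under N(0, I). This is the property a mixture prior offers: latent codes can gather around several modes instead of a single one.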
class pynet.models.vae.gmvae.GMVAENet(input_dim, latent_dim, n_mix_components, dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, dropout=0, temperature=1, gen_bias_init=0.0, prior_gmm=None, decoder=None, encoder_y=None, encoder_gmm=None, random_seed=None)
The Gaussian Mixture VAE architecture.
Meta-GMVAE: Mixture of Gaussian VAE for Unsupervised Meta-Learning, Dong Bok Lee, ICLR 2021.
Gaussian Mixture VAE: Lessons in Variational Inference, Generative Models, and Deep Nets: http://ruishu.io/2016/12/25/gmvae
Deep Unsupervised Clustering with Gaussian Mixture Variational Autoencoders, Nat Dilokthanakul, arXiv 2017.
__init__(input_dim, latent_dim, n_mix_components, dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, dropout=0, temperature=1, gen_bias_init=0.0, prior_gmm=None, decoder=None, encoder_y=None, encoder_gmm=None, random_seed=None)
Init class.
- Parameters
input_dim: int
the input size.
latent_dim: int
the size of the stochastic latent state of the GMVAE.
n_mix_components: int
the number of mixture components.
dense_hidden_dims: list of int, default None
the sizes of the hidden layers of the fully connected network used to condition the distribution on the inputs. If None, then the default is a single-layered dense network.
sigma_min: float, default 0.001
the minimum value that the standard deviation of the distribution over the latent state can take.
raw_sigma_bias: float, default 0.25
a scalar that is added to the raw standard deviation output from the neural networks that parameterize the prior and approximate posterior. Useful for preventing standard deviations close to zero.
dropout: float, default 0
the dropout rate.
temperature: float, default 1
the temperature of the relaxed categorical distribution over the discrete latent variable y. The closer to 0, the closer its samples are to discrete one-hot vectors; the larger the temperature, the closer they are to uniform.
gen_bias_init: float, default 0
a bias added to the raw output of the fully connected network that parameterizes the generative distribution. Useful for initialising the mean to a sensible starting point, e.g. the mean of the training set.
prior_gmm: @callable, default None
a callable that implements the prior distribution p(z | y). Must accept the discrete variable y as argument and return a tf.distributions.MultivariateNormalDiag distribution.
decoder: @callable, default None
a callable that implements the generative distribution p(x | z). Must accept the encoded latent state z as argument and return a subclass of tf.distributions.Distribution that can be used to evaluate the log_prob of the targets.
encoder_y: @callable, default None
a callable that implements the inference distribution q(y | x) over the discrete latent variable y.
encoder_gmm: @callable, default None
a callable that implements the inference distribution q(z | x, y) over the continuous latent variable z.
random_seed: int, default None
the seed for the random operations.
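The sigma_min and raw_sigma_bias parameters point to a common way of turning an unconstrained network output into a standard deviation. You add the bias, apply a softplus, then floor the result at sigma_min. The sketch below assumes this convention; it is not taken from pynet's source, and the helper names are hypothetical.

```python
import math

def softplus(x):
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def positive_sigma(raw_sigma, raw_sigma_bias=0.25, sigma_min=0.001):
    """Hypothetical mapping from a raw network output to a standard deviation.

    Assumed convention: sigma = max(softplus(raw_sigma + raw_sigma_bias), sigma_min).
    The bias raises the initial standard deviation, and the floor guarantees
    sigma >= sigma_min whatever the network outputs.
    """
    return max(softplus(raw_sigma + raw_sigma_bias), sigma_min)

# A raw output of 0 gives softplus(0.25); a very negative one is clamped to sigma_min.
sigma_default = positive_sigma(0.0)
sigma_clamped = positive_sigma(-20.0)
```

The softplus keeps the mapping smooth and strictly positive. The explicit floor is what actually prevents standard deviations from collapsing towards zero.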
decoder(z)
Computes the generative distribution p(x | z).
- Parameters
z: torch.Tensor (num_samples, mix_components, latent_size)
the stochastic latent state z.
- Returns
p(x | z): Bernoulli (batch_size, data_size)
a Bernoulli distribution.
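Because the decoder returns a Bernoulli distribution, log p(x | z) is naturally computed from logits. The plain-Python sketch below shows a numerically stable way to do this. It also shows a hypothetical way to pick gen_bias_init: take the logit of a target mean, so the initial mean matches it when the raw network output starts near zero. Neither part is pynet code.

```python
import math

def softplus(x):
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_log_prob(x, logits):
    """log p(x | z) for independent Bernoulli pixels parameterised by logits.

    Uses log Bernoulli(x | sigmoid(l)) = x * l - softplus(l), which avoids
    computing log(sigmoid(l)) directly and so does not overflow for large |l|.
    """
    return sum(xi * li - softplus(li) for xi, li in zip(x, logits))

def logit(p):
    """Inverse of the sigmoid; maps a mean in (0, 1) to a logit."""
    return math.log(p / (1.0 - p))

# With all-zero logits every pixel has probability 0.5.
log_px = bernoulli_log_prob([1.0, 0.0], [0.0, 0.0])

# Hypothetical choice of gen_bias_init: the logit of a target mean pixel intensity,
# so that sigmoid(0 + gen_bias_init) reproduces that mean at initialisation.
gen_bias_init = logit(0.13)
```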
encoder_gmm(x, y)
Computes the inference distribution q(z | x, y).
- Parameters
x: torch.Tensor (batch_size, data_size)
the input data.
y: torch.Tensor (batch_size, mix_components)
the discrete latent variable.
- Returns
q(z | x, y): MultivariateNormal (batch_size, latent_size)
a multivariate normal distribution with diagonal covariance.
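encoder_gmm returns a diagonal Gaussian q(z | x, y). During training, such a distribution is normally sampled with the reparameterisation trick; in PyTorch this is what rsample() on a Normal or MultivariateNormal distribution provides. The plain-Python sketch below shows the underlying computation, with made-up encoder outputs.

```python
import random

def sample_diag_normal(mean, std, rng):
    """Reparameterised sample z = mean + std * eps with eps ~ N(0, I).

    Expressing z as a deterministic function of (mean, std) plus independent
    noise is what lets gradients flow through the sample in an autodiff framework.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in mean]
    return [m + s * e for m, s, e in zip(mean, std, eps)]

rng = random.Random(0)
mean, std = [1.0, -1.0], [0.1, 0.5]   # made-up outputs of an encoder
samples = [sample_diag_normal(mean, std, rng) for _ in range(20000)]
emp_mean = [sum(s[d] for s in samples) / len(samples) for d in range(len(mean))]
emp_std = [
    (sum((s[d] - emp_mean[d]) ** 2 for s in samples) / len(samples)) ** 0.5
    for d in range(len(mean))
]
```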
encoder_y(x)
Computes the inference distribution q(y | x).
- Parameters
x: torch.Tensor (batch_size, data_size)
the input data to the inference network.
- Returns
q(y | x): RelaxedOneHotCategorical (batch_size, mix_components)
a relaxed one-hot categorical distribution.
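Assuming the RelaxedOneHotCategorical returned here is torch.distributions.RelaxedOneHotCategorical, its samples follow the Gumbel-softmax (Concrete) construction. The GMVAENet temperature parameter presumably sets that distribution's temperature. The plain-Python sketch below reproduces the construction with made-up class probabilities, to show how temperature moves samples between one-hot and uniform.

```python
import math
import random

def relaxed_one_hot_sample(logits, temperature, rng):
    """Gumbel-softmax (Concrete) sample from a relaxed one-hot categorical.

    y_k = softmax((logits_k + g_k) / temperature)_k with g_k ~ Gumbel(0, 1).
    Low temperatures give nearly one-hot vectors; high temperatures give
    nearly uniform ones.
    """
    # Gumbel noise g = -log(-log(u)); u is kept strictly inside (0, 1).
    gumbels = [-math.log(-math.log(rng.uniform(1e-12, 1.0 - 1e-12))) for _ in logits]
    scores = [(l + g) / temperature for l, g in zip(logits, gumbels)]
    top = max(scores)  # subtract the max before exponentiating for stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [math.log(0.2), math.log(0.3), math.log(0.5)]   # made-up q(y | x) probabilities
y_soft = relaxed_one_hot_sample(logits, temperature=1000.0, rng=rng)
y_sharp = relaxed_one_hot_sample(logits, temperature=0.05, rng=rng)
```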
generate_sample_data(z=None, num_samples=1)
Generates mean sample data from the model.
If a latent variable z is provided, data are generated for that point in the latent space; otherwise z is drawn from the prior.
- Parameters
z: torch.Tensor (num_samples, mix_components, latent_size)
the stochastic latent state z.
num_samples: int, default 1
the number of samples to draw from the prior when z is None.
- Returns
recon: torch.Tensor (batch_size, data_size)
the reconstructed mean sample data.
generate_samples(num_samples, clusters=None)
Samples components from the static latent GMM prior.
- Parameters
num_samples: int
number of samples to draw from the static GMM prior.
clusters: list of int, default None
if given, only the clusters with these indices are sampled.
- Returns
z: Tensor (num_samples, mix_components, latent_size)
samples drawn from each component of the GMM if clusters is None. Otherwise, the tensor has shape (num_samples, batch_size, latent_size), where batch_size is the first dimension of clusters, i.e. the number of clusters supplied.
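The shapes documented for generate_samples follow from drawing one latent vector per component for every requested sample. The sketch below mirrors that layout with nested Python lists. The component means and standard deviations are made up, and this is not pynet's implementation.

```python
import random

def sample_each_component(num_samples, means, stds, rng, clusters=None):
    """Draw one sample per mixture component (or per selected cluster) for each of num_samples rows.

    means[k] and stds[k] describe the diagonal Gaussian of component k.
    Returns a nested list shaped (num_samples, n_selected, latent_size).
    """
    selected = range(len(means)) if clusters is None else clusters
    return [
        [[rng.gauss(m, s) for m, s in zip(means[k], stds[k])] for k in selected]
        for _ in range(num_samples)
    ]

rng = random.Random(0)
means = [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]]   # 3 made-up components, latent_size = 2
stds = [[0.5, 0.5]] * 3
z_all = sample_each_component(4, means, stds, rng)                    # shape (4, 3, 2)
z_some = sample_each_component(4, means, stds, rng, clusters=[0, 2])  # shape (4, 2, 2)
```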
prior_gmm(y)
Computes the GMM prior distribution p(z | y).
- Parameters
y: torch.Tensor (batch_size, mix_components)
the discrete intermediate variable y.
- Returns
p(z | y): MultivariateNormal (batch_size, latent_size)
a multivariate normal distribution: the mixture component associated with y.
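How prior_gmm maps y to p(z | y) is not documented here. One common design, assumed purely for illustration, passes y through bias-free linear layers. The first produces the mean and the second, after a softplus, the standard deviation of a diagonal Gaussian. The weights and the helper name below are hypothetical.

```python
import math

def softplus(x):
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def prior_params_from_y(y, w_mean, w_std, sigma_min=0.001):
    """Hypothetical p(z | y): bias-free linear maps of y give a diagonal Gaussian.

    w_mean[k][d] and w_std[k][d] hold the parameters of latent dimension d for
    component k. A one-hot y selects row k exactly; a relaxed y interpolates.
    """
    latent_size = len(w_mean[0])
    mean = [sum(y[k] * w_mean[k][d] for k in range(len(y))) for d in range(latent_size)]
    raw = [sum(y[k] * w_std[k][d] for k in range(len(y))) for d in range(latent_size)]
    std = [max(softplus(r), sigma_min) for r in raw]
    return mean, std

w_mean = [[-2.0, 0.0], [0.0, 1.0], [2.0, 0.0]]   # made-up weights: 3 components, latent_size 2
w_std = [[0.0, 0.0], [1.0, 1.0], [-1.0, -1.0]]
mean_one_hot, std_one_hot = prior_params_from_y([0.0, 1.0, 0.0], w_mean, w_std)
mean_relaxed, _ = prior_params_from_y([0.5, 0.0, 0.5], w_mean, w_std)
```

With a one-hot y, p(z | y) is exactly one Gaussian component, which is why the marginal over y forms a Gaussian mixture.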
class pynet.models.vae.gmvae.VAEGMPNet(input_dim, latent_dim, n_mix_components=1, dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, gen_bias_init=0, dropout=0, prior=None, encoder=None, decoder=None, random_seed=None)
Implementation of a Variational Autoencoder (VAE) with a Gaussian Mixture Prior (GMP).
__init__(input_dim, latent_dim, n_mix_components=1, dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, gen_bias_init=0, dropout=0, prior=None, encoder=None, decoder=None, random_seed=None)
Init class.
- Parameters
input_dim: int
the input size.
latent_dim: int
the size of the stochastic latent state of the model.
n_mix_components: int, default 1
the number of components in the mixture prior. If 1, a classical VAE with prior z ~ N(0, 1) is built.
dense_hidden_dims: list of int, default None
the sizes of the hidden layers of the fully connected network used to condition the distribution on the inputs. If None, then the default is a single-layered dense network.
sigma_min: float, default 0.001
the minimum value that the standard deviation of the distribution over the latent state can take.
raw_sigma_bias: float, default 0.25
a scalar that is added to the raw standard deviation output from the neural networks that parameterize the prior and approximate posterior. Useful for preventing standard deviations close to zero.
gen_bias_init: float, default 0
a bias added to the raw output of the fully connected network that parameterizes the generative distribution. Useful for initialising the mean to a sensible starting point, e.g. the mean of the training set.
dropout: float, default 0
the dropout rate.
prior: @callable, default None
a callable that implements the prior distribution p(z).
encoder: @callable, default None
a callable that implements the inference distribution q(z | x).
decoder: @callable, default None
a callable that implements the generative distribution p(x | z). Must accept the latent state z as argument and return a distribution that can be used to evaluate the log_prob of the targets.
random_seed: int, default None
the seed for the random operations.
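With n_mix_components=1 the prior is N(0, 1), and the KL term of the ELBO has a closed form. The KL divergence between a Gaussian posterior and a Gaussian mixture prior has no closed form, so it is typically estimated by sampling from q(z | x). The plain-Python sketch below contrasts the two cases. Whether VAEGMPNet's loss uses exactly this estimator is an assumption, and the encoder outputs and mixture are made up.

```python
import math
import random

def log_normal_diag(z, mean, std):
    """Log-density of a diagonal Gaussian N(mean, diag(std**2)) at z."""
    return sum(
        -0.5 * math.log(2.0 * math.pi) - math.log(s) - 0.5 * ((zi - m) / s) ** 2
        for zi, m, s in zip(z, mean, std)
    )

def kl_to_standard_normal(mean, std):
    """Closed-form KL(N(mean, diag(std**2)) || N(0, I))."""
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s) for m, s in zip(mean, std))

def kl_monte_carlo(mean, std, log_prior, num_samples, rng):
    """Monte Carlo estimate of KL(q || p) = E_q[log q(z) - log p(z)] for any prior log-density."""
    total = 0.0
    for _ in range(num_samples):
        z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mean, std)]
        total += log_normal_diag(z, mean, std) - log_prior(z)
    return total / num_samples

def log_mixture_prior(z):
    """Made-up 2-component, equally weighted diagonal GMM prior (no closed-form KL)."""
    terms = [math.log(0.5) + log_normal_diag(z, m, [1.0, 1.0]) for m in ([-2.0, 0.0], [2.0, 0.0])]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))

rng = random.Random(0)
mean, std = [0.5, -0.5], [0.8, 1.2]   # made-up encoder outputs for one input
kl_exact = kl_to_standard_normal(mean, std)
kl_mc_standard = kl_monte_carlo(
    mean, std, lambda z: log_normal_diag(z, [0.0, 0.0], [1.0, 1.0]), 20000, rng)
kl_mc_mixture = kl_monte_carlo(mean, std, log_mixture_prior, 20000, rng)
```

For the standard Normal prior the Monte Carlo estimate agrees with the closed form. The same estimator then applies unchanged to the mixture prior.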
decoder(z)
Computes the generative distribution p(x | z).
- Parameters
z: torch.Tensor (batch_size, latent_size)
the stochastic latent state z.
- Returns
p(x | z): @callable
the distribution p(x | z) with shape (batch_size, data_size).
encoder(x)
Computes the inference distribution q(z | x).
- Parameters
x: torch.Tensor (batch_size, data_size)
the input data.
- Returns
q(z | x): @callable
the distribution q(z | x) with shape (batch_size, latent_size).
generate_sample_data(z=None, num_samples=1)
Generates mean sample data from the model.
If a latent variable z is provided, data are generated for that point in the latent space; otherwise z is drawn from the prior.
- Parameters
z: torch.Tensor (num_samples, latent_size)
the stochastic latent state z.
num_samples: int, default 1
the number of samples to draw from the prior when z is None.
- Returns
recon: torch.Tensor (batch_size, data_size)
the reconstructed mean sample data.
generate_samples(num_samples)
Generates num_samples samples from the model prior.
- Parameters
num_samples: int
number of samples to draw from the prior distribution.
- Returns
z: Tensor (num_samples, latent_size)
samples drawn from the prior distribution.
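GMVAENet.generate_samples returns one draw per component. Here the returned shape (num_samples, latent_size) suggests sampling the mixture prior as a whole. The standard way to do that is ancestral sampling: pick a component according to the mixture weights, then sample that component's Gaussian. The sketch is plain Python with made-up parameters and a hypothetical helper name. With a single component it reduces to sampling one Gaussian.

```python
import random

def sample_mixture_prior(num_samples, weights, means, stds, rng):
    """Ancestral sampling from a diagonal GMM prior p(z) = sum_k w_k N(z; mu_k, diag(sigma_k**2)).

    For each sample, first draw a component index k with probability weights[k],
    then draw z from that component's Gaussian. Returns the samples, shaped
    (num_samples, latent_size), and the component index used for each one.
    """
    samples, components = [], []
    for _ in range(num_samples):
        k = rng.choices(range(len(weights)), weights=weights)[0]
        samples.append([rng.gauss(m, s) for m, s in zip(means[k], stds[k])])
        components.append(k)
    return samples, components

rng = random.Random(0)
weights = [0.25, 0.75]                       # made-up mixture weights
means = [[-3.0, -3.0], [3.0, 3.0]]
stds = [[0.5, 0.5], [0.5, 0.5]]
z, ks = sample_mixture_prior(4000, weights, means, stds, rng)
```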
prior()
Gets the prior distribution p(z).
- Returns
p(z): @callable
the distribution p(z) with shape (batch_size, latent_size).