Menu

Helper Module for Deep Learning.

Source code for pynet.models.vae.gmvae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Gaussian Mixture Variational Auto-Encoder (GMVAE).

Two implementations are proposed:

* VAEGMP is an adaptation of VAE to make use of a Gaussian Mixture prior,
  instead of a standard Normal distribution.
* GMVAE is an attempt to replicate the work described in [1] and [2]

[1] Gaussian Mixture VAE: Lessons in Variational Inference, Generative Models,
and Deep Nets: http://ruishu.io/2016/12/25/gmvae
[2] Deep Unsupervised Clustering with Gaussian Mixture Variational Autoencoders
Nat Dilokthanakul, arXiv 2017.
Code: https://github.com/jariasf/GMVAE
Code: https://github.com/mazrk7/gmvae
"""

# Imports
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.distributions import Normal, Categorical, Independent
try:
    from torch.distributions import MixtureSameFamily
except:
    pass
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks, init_weight
from pynet.models.vae.vae import Encoder
from pynet.models.vae.distributions import (
    ConditionalNormal, Gaussian, ConditionalBernoulli, ConditionalCategorical)


@Networks.register
@DeepLearningDecorator(family=("encoder", "vae", "classifier"))
class GMVAENet(nn.Module):
    """ The Gaussian Mixture VAE architecture.

    Meta-GMVAE: Mixture of Gaussian VAE for Unsupervised Meta-Learning
    Dong Bok Lee, ICLR 2021.

    Gaussian Mixture VAE: Lessons in Variational Inference, Generative Models,
    and Deep Nets: http://ruishu.io/2016/12/25/gmvae

    Deep Unsupervised Clustering with Gaussian Mixture Variational Autoencoders
    Nat Dilokthanakul, arXiv 2017.
    """

    def __init__(self, input_dim, latent_dim, n_mix_components,
                 dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25,
                 dropout=0, temperature=1, gen_bias_init=0., prior_gmm=None,
                 decoder=None, encoder_y=None, encoder_gmm=None,
                 random_seed=None):
        """ Init class.

        Parameters
        ----------
        input_dim: int
            the input size.
        latent_dim: int,
            the size of the stochastic latent state of the GMVAE.
        n_mix_components: int
            the number of mixture components.
        dense_hidden_dims: list of int, default None
            the sizes of the hidden layers of the fully connected
            network used to condition the distribution on the inputs. If None,
            then the default is a single-layered dense network.
        sigma_min: float, default 0.001
            the minimum value that the standard deviation of the
            distribution over the latent state can take.
        raw_sigma_bias: float, default 0.25
            a scalar that is added to the raw standard deviation
            output from the neural networks that parameterize the prior and
            approximate posterior. Useful for preventing standard deviations
            close to zero.
        dropout: float, default 0
            define the dropout rate.
        temperature: float, default 1
            degree of how approximately discrete the distribution is. The
            closer to 0, the more discrete and the closer to infinity, the
            more uniform.
        gen_bias_init: float, default 0
            a bias added to the raw output of the fully connected network
            that parameterizes the generative distribution. Useful for
            initialising the mean to a sensible starting point e.g. mean of
            training set.
        prior_gmm: @callable, default None
            a callable that implements the prior distribution p(z | y).
            It is called with a one-element list holding the discrete
            variable y and must return a distribution object over z.
        decoder: @callable, default None
            a callable that implements the generative distribution p(x | z).
            It is called with a one-element list holding the latent state z
            and must return a distribution that can be used to evaluate the
            log_prob of the targets.
        encoder_y: @callable, default None
            a callable that implements the inference q(y | x) over the
            discrete latent variable y. It is called with a one-element list
            holding x.
        encoder_gmm: @callable, default None
            a callable that implements the inference q(z | x, y) over the
            continuous latent variable z. It is called with the list [x, y].
        random_seed: int, default None
            the seed for the random operations.
        """
        super(GMVAENet, self).__init__()
        self.n_mix_components = n_mix_components
        self.random_seed = random_seed

        # Prior p(z | y) is a learned mixture of Gaussians, where mu and
        # sigma are output from a fully connected network conditioned on y
        if prior_gmm is not None:
            self._prior_gmm = prior_gmm
        else:
            self._prior_gmm = ConditionalNormal(
                input_dim=n_mix_components,
                final_dim=latent_dim,
                dense_hidden_dims=None,
                sigma_min=sigma_min,
                raw_sigma_bias=raw_sigma_bias,
                dropout=dropout)

        # The generative distribution p(x | z) is conditioned on the latent
        # state variable z via a fully connected network
        if decoder is not None:
            self._decoder = decoder
        else:
            self._decoder = ConditionalBernoulli(
                input_dim=latent_dim,
                final_dim=input_dim,
                dense_hidden_dims=dense_hidden_dims,
                bias_init=gen_bias_init)

        # A callable that implements the inference distribution q(y | x)
        # Use the Gumbel-Softmax distribution to model the categorical latent
        # variable
        if encoder_y is not None:
            self._encoder_y = encoder_y
        else:
            self._encoder_y = ConditionalCategorical(
                input_dim=input_dim,
                final_dim=n_mix_components,
                temperature=temperature,
                dense_hidden_dims=dense_hidden_dims)

        # A callable that implements the inference distribution q(z | x, y)
        if encoder_gmm is not None:
            self._encoder_gmm = encoder_gmm
        else:
            self._encoder_gmm = ConditionalNormal(
                input_dim=input_dim + n_mix_components,
                final_dim=latent_dim,
                dense_hidden_dims=dense_hidden_dims,
                sigma_min=sigma_min,
                raw_sigma_bias=raw_sigma_bias)

    def _seed(self):
        """ Reset the global torch RNG when a seed was configured.
        """
        if self.random_seed is not None:
            torch.manual_seed(self.random_seed)

    def _module_device(self):
        """ Return the device of the first registered parameter, or None
        when the module holds no parameter (e.g. only custom callables).
        """
        param = next(self.parameters(), None)
        return None if param is None else param.device

    @staticmethod
    def _mean_of(dist):
        """ Return the mean of a distribution.

        torch.distributions exposes 'mean' as a property, whereas other
        distribution implementations may expose it as a method: both
        conventions are supported.
        """
        mean = dist.mean
        return mean() if callable(mean) else mean

    @staticmethod
    def _draw(dist):
        """ Draw one sample, using the reparameterised path when the
        distribution advertises one (torch 'has_rsample' flag) so that
        gradients can flow through the sample.
        """
        if getattr(dist, "has_rsample", False):
            return dist.rsample()
        return dist.sample()

    def prior_gmm(self, y):
        """ Computes the GMM prior distribution p(z | y).

        Parameters
        ----------
        y: torch.Tensor (batch_size, mix_components)
            the discrete intermediate variable y.

        Returns
        -------
        p(z | y): distribution (batch_size, latent_size)
            the distribution returned by the prior callable.
        """
        return self._prior_gmm([y])

    def decoder(self, z):
        """ Computes the generative distribution p(x | z).

        Parameters
        ----------
        z: torch.Tensor (num_samples, mix_components, latent_size)
            the stochastic latent state z.

        Returns
        -------
        p(x | z): distribution (batch_size, data_size)
            the distribution returned by the decoder callable (a Bernoulli
            with the default decoder).
        """
        return self._decoder([z])

    def encoder_y(self, x):
        """ Computes the inference distribution q(y | x).

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data to the inference network.

        Returns
        -------
        q(y | x): distribution (batch_size, mix_components)
            the distribution returned by the encoder callable (a relaxed
            one hot Categorical with the default encoder).
        """
        return self._encoder_y([x])

    def encoder_gmm(self, x, y):
        """ Computes the inference distribution q(z | x, y).

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.
        y: torch.Tensor (batch_size, mix_components)
            discrete variable.

        Returns
        -------
        q(z | x, y): distribution (batch_size, latent_size)
            the distribution returned by the encoder callable.
        """
        return self._encoder_gmm([x, y])

    def reconstruct(self, x):
        """ Reconstruct the data from the model.

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        recon: torch.Tensor (batch_size, data_size)
            the reconstructed data.
        """
        z = self.transform(x)
        return self.generate_sample_data(z=z)

    def generate_sample_data(self, z=None, num_samples=1):
        """ Generates mean sample data from the model.

        Can provide latent variable 'z' to generate data for
        this point in the latent space, else draw from prior.

        Parameters
        ----------
        z: torch.Tensor (num_samples, mix_components, latent_size)
            the stochastic latent state z.
        num_samples: int, default 1
            number of samples drawn from the prior when 'z' is None.

        Returns
        -------
        recon: torch.Tensor (batch_size, data_size)
            the reconstructed mean samples data.
        """
        if z is None:
            z = self.generate_samples(num_samples)
        p_x_given_z = self.decoder(z)
        return self._mean_of(p_x_given_z)

    def transform(self, x):
        """ Transform inputs 'x' to yield a latent code.

        The code is obtained by sampling y from q(y | x) and then z from
        q(z | x, y).

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        z: torch.Tensor
            the stochastic latent state z.
        """
        q_y_given_x = self.encoder_y(x)
        self._seed()
        y = q_y_given_x.sample()
        q_z_given_xy = self.encoder_gmm(x, y)
        self._seed()
        return q_z_given_xy.sample()

    def generate_samples(self, num_samples, clusters=None):
        """ Samples components from the static latent GMM prior.

        Parameters
        ----------
        num_samples: int
            number of samples to draw from the static GMM prior.
        clusters: list of int, default None
            if desired, can sample from a specific batch of clusters.

        Returns
        -------
        z: Tensor (num_samples, mix_components, latent_size)
            representing samples drawn from each component of the GMM if
            clusters is None else if clusters the Tensor is of shape
            (num_samples, batch_size, latent_size) where batch_size is the
            first dimension of clusters, depending on how many were supplied.
        """
        device = self._module_device()
        # If no specific clusters supplied, sample from each component in GMM.
        # 'arange' yields the integer indices 0..n-1 expected by 'one_hot'
        # (the deprecated 'torch.range' is end-inclusive and returns floats).
        if clusters is None:
            clusters = torch.arange(self.n_mix_components, device=device)
        else:
            clusters = torch.as_tensor(
                clusters, dtype=torch.long, device=device)
        # 'one_hot' returns an integer tensor: cast to float as in 'forward',
        # where y is a real valued relaxed sample fed to the prior network.
        y = func.one_hot(clusters, self.n_mix_components).float()
        p_z_given_y = self.prior_gmm(y)

        # Draw 'num_samples' samples from each cluster
        # Return shape: [num_samples, mix_components, latent_size]
        self._seed()
        # torch distributions expect a shape, not an int
        return p_z_given_y.sample(torch.Size([num_samples]))

    def forward(self, x):
        """ The forward method.

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        p_x_given_z: distribution
            the generative distribution p(x | z).
        dists: dict
            the intermediate distributions and samples with keys
            'q_y_given_x', 'y', 'p_z_given_y', 'q_z_given_xy' and 'z'.
        """
        # Encoder accepts images x and implements q(y | x)
        q_y_given_x = self.encoder_y(x)

        # Sample categorical variable y from the Gumbel-Softmax distribution.
        # The reparameterised sample is used when available: with torch
        # distributions 'sample()' is not differentiable, which would cut
        # the gradient between the decoder and q(y | x).
        self._seed()
        y = self._draw(q_y_given_x)

        # Prior accepts y as input and implements p(z | y)
        p_z_given_y = self.prior_gmm(y)

        # Encoder accept images x and y as inputs to implement q(z | x, y)
        q_z_given_xy = self.encoder_gmm(x, y)

        # Sample latent Gaussian variable z
        self._seed()
        z = q_z_given_xy.rsample()

        # Generative distribution p(x | z)
        p_x_given_z = self.decoder(z)

        return p_x_given_z, {
            "q_y_given_x": q_y_given_x, "y": y, "p_z_given_y": p_z_given_y,
            "q_z_given_xy": q_z_given_xy, "z": z}
@Networks.register
@DeepLearningDecorator(family=("encoder", "vae"))
class VAEGMPNet(nn.Module):
    """ Implementation of a Variational Autoencoder (VAE) with Gaussian
    Mixture Prior (GMP).
    """

    def __init__(self, input_dim, latent_dim, n_mix_components=1,
                 dense_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25,
                 gen_bias_init=0, dropout=0, prior=None, encoder=None,
                 decoder=None, random_seed=None):
        """ Init class.

        Parameters
        ----------
        input_dim: int
            the input size.
        latent_dim: int,
            the size of the stochastic latent state of the GMVAE.
        n_mix_components: int, default 1
            the number of components in the mixture prior. If 1, a classical
            VAE is generated with prior z ~ N(0, 1).
        dense_hidden_dims: list of int, default None
            the sizes of the hidden layers of the fully connected
            network used to condition the distribution on the inputs. If None,
            then the default is a single-layered dense network.
        sigma_min: float, default 0.001
            the minimum value that the standard deviation of the
            distribution over the latent state can take.
        raw_sigma_bias: float, default 0.25
            a scalar that is added to the raw standard deviation
            output from the neural networks that parameterize the prior and
            approximate posterior. Useful for preventing standard deviations
            close to zero.
        gen_bias_init: float, default 0
            a bias added to the raw output of the fully connected network
            that parameterizes the generative distribution. Useful for
            initialising the mean to a sensible starting point e.g. mean of
            training set.
        dropout: float, default 0
            define the dropout rate (not used by the default sub-networks
            built here).
        prior: @callable, default None
            a distribution that implements p(z).
        encoder: @callable, default None
            a distribution that implements inference q(z | x).
        decoder: @callable, default None
            a distribution that implements p(x | z). Must accept as arguments
            the latent state z and return a distribution that can be used to
            evaluate the log_prob of the targets.
        random_seed: int, default None
            the seed for the random operations.
        """
        super(VAEGMPNet, self).__init__()
        self.n_mix_components = n_mix_components
        self.random_seed = random_seed
        self._latent_dim = latent_dim

        # Prior p(z): either user supplied, a learned mixture of Gaussians
        # or a standard Normal. A user supplied prior is stored as is; the
        # built-in priors are instantiated on demand in 'prior()'.
        self._prior = prior
        if prior is None and self.n_mix_components > 1:
            try:
                from torch.distributions import MixtureSameFamily  # noqa
            except ImportError:
                raise ImportError(
                    "A mixture prior (n_mix_components > 1) requires "
                    "'torch.distributions.MixtureSameFamily' which is not "
                    "available in the installed torch version.")
            # The mixture parameters are registered on the module so that
            # they are returned by 'parameters()', moved by 'to()' and saved
            # in the state dict (NOTE: this adds three entries to the state
            # dict compared to a model where they were not registered).
            self._prior_loc = nn.Parameter(
                torch.zeros(self.n_mix_components, latent_dim),
                requires_grad=True)
            self._prior_raw_scale = nn.Parameter(
                torch.ones(self.n_mix_components, latent_dim),
                requires_grad=True)
            # Zero logits give the uniform mixture weights 1 / n while
            # remaining valid after any gradient update.
            self._prior_mixture_logits = nn.Parameter(
                torch.zeros(self.n_mix_components), requires_grad=True)

        # The generative distribution p(x | z) is conditioned on the latent
        # state variable z via a fully connected network
        if decoder is not None:
            self._decoder = decoder
        else:
            self._decoder = ConditionalBernoulli(
                input_dim=latent_dim,
                final_dim=input_dim,
                dense_hidden_dims=dense_hidden_dims,
                bias_init=gen_bias_init)

        # A callable that implements the inference distribution q(z | x)
        if encoder is not None:
            self._encoder = encoder
        else:
            self._encoder = ConditionalNormal(
                input_dim=input_dim,
                final_dim=latent_dim,
                dense_hidden_dims=dense_hidden_dims,
                sigma_min=sigma_min,
                raw_sigma_bias=raw_sigma_bias)

    def _seed(self):
        """ Reset the global torch RNG when a seed was configured.
        """
        if self.random_seed is not None:
            torch.manual_seed(self.random_seed)

    def _module_device(self):
        """ Return the device of the first registered parameter, or None
        when the module holds no parameter (e.g. only custom callables).
        """
        param = next(self.parameters(), None)
        return None if param is None else param.device

    @staticmethod
    def _mean_of(dist):
        """ Return the mean of a distribution.

        torch.distributions exposes 'mean' as a property, whereas other
        distribution implementations may expose it as a method: both
        conventions are supported.
        """
        mean = dist.mean
        return mean() if callable(mean) else mean

    def prior(self):
        """ Get the prior distribution p(z).

        Returns
        -------
        p(z): @callable
            the user supplied prior if any, else a mixture of diagonal
            Gaussians when 'n_mix_components' > 1, else a standard Normal.
        """
        if self._prior is not None:
            return self._prior
        if self.n_mix_components > 1:
            from torch.distributions import MixtureSameFamily
            # Rebuilt at each call so that the distribution always reflects
            # the current parameter values and a fresh autograd graph.
            mix = Categorical(logits=self._prior_mixture_logits)
            comp = Independent(
                Normal(loc=self._prior_loc,
                       scale=func.softplus(self._prior_raw_scale)),
                reinterpreted_batch_ndims=1)
            return MixtureSameFamily(
                mixture_distribution=mix, component_distribution=comp)
        # Standard Normal created on the module device
        loc = torch.zeros(self._latent_dim, device=self._module_device())
        return Normal(loc=loc, scale=1)

    def decoder(self, z):
        """ Computes the generative distribution p(x | z).

        Parameters
        ----------
        z: torch.Tensor (batch_size, latent_size)
            the stochastic latent state z.

        Returns
        -------
        p(x | z): @callable
            the distribution p(x | z) with shape (batch_size, data_size).
        """
        return self._decoder([z])

    def encoder(self, x):
        """ Computes the inference distribution q(z | x).

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        q(z | x): @callable
            the distribution q(z | x) with shape (batch_size, latent_size).
        """
        return self._encoder([x])

    def reconstruct(self, x):
        """ Reconstruct the data from the model.

        The latent code is sampled from q(z | x) in training mode and set
        to the mean of q(z | x) otherwise.

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        recon: torch.Tensor (batch_size, data_size)
            the reconstructed data.
        """
        q_z = self.encoder(x)
        self._seed()
        if self.training:
            z = q_z.sample()
        else:
            z = self._mean_of(q_z)
        p_x_given_z = self.decoder(z)
        return self._mean_of(p_x_given_z)

    def generate_sample_data(self, z=None, num_samples=1):
        """ Generates mean sample data from the model.

        Can provide latent variable 'z' to generate data for
        this point in the latent space, else draw from prior.

        Parameters
        ----------
        z: torch.Tensor (num_samples, latent_size)
            the stochastic latent state z.
        num_samples: int, default 1
            number of samples drawn from the prior when 'z' is None.

        Returns
        -------
        recon: torch.Tensor (batch_size, data_size)
            the reconstructed mean samples data.
        """
        if z is None:
            z = self.generate_samples(num_samples)
        p_x_given_z = self.decoder(z)
        return self._mean_of(p_x_given_z)

    def transform(self, x):
        """ Transform inputs 'x' to yield mean latent code.

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        z: torch.Tensor (num_samples, latent_size)
            the mean of q(z | x).
        """
        q_z = self.encoder(x)
        return self._mean_of(q_z)

    def generate_samples(self, num_samples):
        """ Generate 'num_samples' samples from the model prior.

        Parameters
        ----------
        num_samples: int
            number of samples to draw from the prior distribution.

        Returns
        -------
        z: Tensor (num_samples, latent_size)
            representing samples drawn from the prior distribution.
        """
        p_z = self.prior()
        self._seed()
        # torch distributions expect a shape, not an int
        return p_z.sample(torch.Size([num_samples]))

    def forward(self, x):
        """ The forward method.

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        p_x_given_z: distribution
            the generative distribution p(x | z).
        dists: dict
            the prior, posterior and latent sample with keys 'p_z',
            'q_z_given_x' and 'z'.
        """
        # Prior with Gaussian distribution p(z)
        p_z = self.prior()

        # Encoder accept images x implement q(z | x)
        q_z_given_x = self.encoder(x)
        self._seed()
        z = q_z_given_x.rsample()

        # Generative distribution p(x | z)
        p_x_given_z = self.decoder(z)

        return p_x_given_z, {"p_z": p_z, "q_z_given_x": q_z_given_x, "z": z}

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.