
Helper Module for Deep Learning.

Mixture of Experts VAE with similarity constraint: MoE-Sim-VAE.

Reference: Mixture-of-Experts Variational Autoencoder for Clustering and Generating from Similarity-Based Representations on Single Cell Data, Andreas Kopf, arXiv 2020.

class pynet.models.vae.moevae.MOESimVAENet(input_dim, latent_dim, n_mix_components=1, dense_hidden_dims=None, classifier_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, gen_bias_init=0, dropout=0.5, random_seed=None)

Implementation of a Mixture of Experts VAE with similarity constraint.

__init__(input_dim, latent_dim, n_mix_components=1, dense_hidden_dims=None, classifier_hidden_dims=None, sigma_min=0.001, raw_sigma_bias=0.25, gen_bias_init=0, dropout=0.5, random_seed=None)

Initialize the MoE-Sim-VAE network.

Parameters

input_dim: int

the input size.

latent_dim: int

the size of the stochastic latent state of the MoE-Sim-VAE.

n_mix_components: int, default 1

the number of components in the mixture prior. If 1, the model reduces to a classical VAE with a standard normal prior z ~ N(0, 1).

dense_hidden_dims: list of int, default None

the sizes of the hidden layers of the fully connected network used to condition the distribution on the inputs. If None, a single-layer dense network is used.

classifier_hidden_dims: list of int, default None

the sizes of the hidden layers of the classifier.

sigma_min: float, default 0.001

the minimum value that the standard deviation of the distribution over the latent state can take.

raw_sigma_bias: float, default 0.25

a scalar that is added to the raw standard deviation output from the neural networks that parameterize the prior and approximate posterior. Useful for preventing standard deviations close to zero.

gen_bias_init: float, default 0

a bias added to the raw output of the fully connected network that parameterizes the generative distribution. Useful for initialising the mean to a sensible starting point, e.g. the mean of the training set.

dropout: float, default 0.5

the dropout rate.

random_seed: int, default None

the seed for the random operations.
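The role of n_mix_components in the prior can be illustrated with a small, library-independent sketch. It assumes a Gaussian mixture prior with uniform component weights and diagonal covariances; the helper name sample_prior and the component means and standard deviations used below are hypothetical and are not taken from pynet. With a single component and default arguments it reduces to the standard normal prior z ~ N(0, 1).

```python
import random
from typing import List, Optional, Sequence


def sample_prior(latent_dim: int,
                 n_mix_components: int = 1,
                 means: Optional[Sequence[Sequence[float]]] = None,
                 stds: Optional[Sequence[Sequence[float]]] = None,
                 rng: Optional[random.Random] = None) -> List[float]:
    """Draw one latent vector z from a hypothetical Gaussian mixture prior.

    Assumptions (not pynet's API): uniform mixture weights and diagonal
    covariances. With one component and no means/stds, z ~ N(0, I).
    """
    rng = rng or random.Random()
    if means is None:
        means = [[0.0] * latent_dim for _ in range(n_mix_components)]
    if stds is None:
        stds = [[1.0] * latent_dim for _ in range(n_mix_components)]
    if len(means) != n_mix_components or len(stds) != n_mix_components:
        raise ValueError("need one mean and one std vector per component")
    # Pick a mixture component uniformly at random (assumed equal weights).
    k = rng.randrange(n_mix_components)
    # Draw z = mu_k + sigma_k * eps with eps ~ N(0, 1) per dimension.
    return [mu + sd * rng.gauss(0.0, 1.0)
            for mu, sd in zip(means[k], stds[k])]


# One component: the classical VAE prior z ~ N(0, I).
z_vae = sample_prior(latent_dim=4, rng=random.Random(0))
# Two well-separated components (illustrative means only).
z_mix = sample_prior(latent_dim=2, n_mix_components=2,
                     means=[[-5.0, -5.0], [5.0, 5.0]],
                     stds=[[0.1, 0.1], [0.1, 0.1]],
                     rng=random.Random(1))
```

With several components, each draw lands near one component mean, which is what lets a mixture prior express cluster structure in the latent space.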
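The descriptions of sigma_min and raw_sigma_bias match a common way of turning an unconstrained network output into a standard deviation: add the bias, apply a softplus to make the value positive, then clamp it from below at sigma_min. The sketch below assumes this softplus-then-clamp form; pynet's exact formula may differ, and the function name positive_sigma is hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    # max(x, 0) + log1p(exp(-|x|)) equals log(1 + exp(x)) without overflow.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def positive_sigma(raw_sigma: float,
                   raw_sigma_bias: float = 0.25,
                   sigma_min: float = 0.001) -> float:
    """Map a raw network output to a standard deviation.

    Assumed form: sigma = max(softplus(raw_sigma + raw_sigma_bias), sigma_min).
    """
    return max(softplus(raw_sigma + raw_sigma_bias), sigma_min)


print(positive_sigma(0.0))    # the bias alone keeps sigma away from zero
print(positive_sigma(-50.0))  # -> 0.001 (the sigma_min floor)
```

The bias shifts the softplus input so that a raw output near zero still yields a comfortably positive sigma, while sigma_min guarantees a hard lower bound even for very negative raw outputs.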

decoder(z)

Computes the generative distribution p(x | z).

Parameters

z: torch.Tensor (num_samples, n_mix_components, latent_dim)

the stochastic latent state z.

Returns

p(x | z): list of Bernoulli (n_mix_components, )

one Bernoulli distribution per decoder (expert).
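Because decoder returns one Bernoulli distribution per mixture component, the reconstruction term can be evaluated separately for each expert. The library-independent sketch below computes log p(x | z) for each component from hypothetical raw decoder outputs, adding gen_bias_init to every raw output as the parameter description states. It assumes the raw outputs are Bernoulli logits and that setting gen_bias_init to the logit of the training-set mean starts the decoder mean at that value; both are assumptions, not documented pynet behaviour. The helper names are hypothetical. In PyTorch, torch.distributions.Bernoulli(logits=...).log_prob(x).sum(-1) computes the same per-expert quantity.

```python
import math
from typing import List, Sequence


def softplus(v: float) -> float:
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def bernoulli_log_prob(x: Sequence[int], logits: Sequence[float]) -> float:
    """Sum over features of log Bernoulli(x_i | sigmoid(logit_i)).

    Uses log sigmoid(l) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l).
    """
    return sum(-softplus(-li) if xi == 1 else -softplus(li)
               for xi, li in zip(x, logits))


def per_expert_log_likelihood(x: Sequence[int],
                              raw_outputs: Sequence[Sequence[float]],
                              gen_bias_init: float = 0.0) -> List[float]:
    """Return log p(x | z) under each expert's Bernoulli decoder.

    raw_outputs holds one raw output vector per mixture component
    (hypothetical values); gen_bias_init is added to every raw output.
    """
    return [bernoulli_log_prob(x, [r + gen_bias_init for r in raw])
            for raw in raw_outputs]


# Hypothetical gen_bias_init: the logit of a training-set mean of 0.2, so an
# expert whose raw output is zero starts with mean probability 0.2.
train_mean = 0.2
bias = math.log(train_mean / (1.0 - train_mean))

x = [1, 0, 0, 1]
raw_outputs = [
    [0.0, 0.0, 0.0, 0.0],    # expert 1: predicts the training mean everywhere
    [4.0, -4.0, -4.0, 4.0],  # expert 2: raw outputs that agree with x
]
log_liks = per_expert_log_likelihood(x, raw_outputs, gen_bias_init=bias)
```

Working with logits and softplus, rather than probabilities and log, avoids taking the log of values that underflow to zero when a logit is very large in magnitude.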

forward(x)

Run the forward pass of the network on the input x.


© 2019, pynet developers.