Helper Module for Deep Learning.
Source code for pynet.losses.generative
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Module that provides generative losses.
Code: https://github.com/YannDubs/disentangling-vae
"""
# Imports
import math
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
from torch.nn import functional as func
from torch.distributions import Bernoulli, Normal, Laplace, kl_divergence
from pynet.utils import Losses
def get_vae_loss(loss_name, **kwargs):
    """ Return the correct VAE loss function given the input arguments.

    The parameters for each loss:

    - vae: -
    - betah: beta
    - betab: C_init, C_fin, gamma
    - factor: device, gamma, latent_dim, lr_disc
    - btcvae: dataset_size, alpha, beta, gamma, is_mss (default True)
    - sparse: beta

    Every loss also accepts the optional 'steps_anneal' (default 0) and
    'use_mse' (default False) parameters, which are forwarded to 'BaseLoss'.

    Parameters
    ----------
    loss_name: str
        the name of the loss.
    kwargs: dict
        the loss kwargs.

    Returns
    -------
    loss: @callable
        the loss function.

    Raises
    ------
    ValueError
        if 'loss_name' is not a known loss.
    KeyError
        if a loss-specific parameter listed above is missing.
    """
    # Fall back on the 'BaseLoss' defaults so callers do not have to pass
    # the annealing / MSE options explicitly.
    common_kwargs = dict(steps_anneal=kwargs.get("steps_anneal", 0),
                         use_mse=kwargs.get("use_mse", False))
    if loss_name == "betah":
        loss = BetaHLoss(beta=kwargs["beta"], **common_kwargs)
    elif loss_name == "vae":
        loss = BetaHLoss(beta=1, **common_kwargs)
    elif loss_name == "betab":
        loss = BetaBLoss(C_init=kwargs["C_init"], C_fin=kwargs["C_fin"],
                         gamma=kwargs["gamma"], **common_kwargs)
    elif loss_name == "factor":
        # 'device' is a required argument of 'FactorKLoss': it was previously
        # never forwarded, which made this branch always fail.
        loss = FactorKLoss(device=kwargs["device"],
                           gamma=kwargs["gamma"],
                           disc_kwargs=dict(latent_dim=kwargs["latent_dim"]),
                           optim_kwargs=dict(lr=kwargs["lr_disc"],
                                             betas=(0.5, 0.9)),
                           **common_kwargs)
    elif loss_name == "btcvae":
        loss = BtcvaeLoss(dataset_size=kwargs["dataset_size"],
                          alpha=kwargs["alpha"], beta=kwargs["beta"],
                          gamma=kwargs["gamma"],
                          is_mss=kwargs.get("is_mss", True),
                          **common_kwargs)
    elif loss_name == "sparse":
        loss = SparseLoss(beta=kwargs["beta"], **common_kwargs)
    else:
        raise ValueError("Unknown loss: {}".format(loss_name))
    return loss
class BaseLoss(object):
    """ Base class for losses.

    Stores the state shared by the VAE losses of this module: the
    intermediate layer outputs ('layer_outputs', filled externally, e.g. by
    a callback), the annealing step counter ('n_train_steps') and a cache of
    monitored values ('iteration', 'll' and 'kl' entries).
    """
    def __init__(self, steps_anneal=0, use_mse=False):
        """ Init class.

        Parameters
        ----------
        steps_anneal: int, default 0
            number of annealing steps where gradually adding the
            regularisation.
        use_mse: bool, default False
            if set use MSE for the reconstruction loss rather than Log
            Likelihood.
        """
        self.n_train_steps = 0
        self.layer_outputs = None
        self.steps_anneal = steps_anneal
        self.use_mse = use_mse
        self.cache = {}

    def get_params(self):
        """ Get forward layers outputs.

        Returns
        -------
        q: torch.distributions
            probabilistic encoder (or estimated posterior probability
            function).
        z: torch.Tensor
            the compressed code learned in the bottleneck layer.
        model: nn.Module
            the network.

        Raises
        ------
        ValueError
            if no layer outputs have been registered.
        """
        if self.layer_outputs is None:
            raise ValueError("The model needs to return the latent space "
                             "distribution parameters q and sampling z as "
                             "well as the model itself.")
        z = self.layer_outputs["z"]
        q = self.layer_outputs["q"]
        model = self.layer_outputs["model"]
        return q, z, model

    def reconstruction_loss(self, p, data):
        """ Calculates the per image reconstruction loss for a batch of data
        (i.e. negative log likelihood).

        The distribution of the likelihood on each pixel implicitly
        defines the loss. Bernoulli corresponds to a binary cross entropy.
        Gaussian distribution corresponds to MSE, and is sometimes
        used, but hard to train because it ends up focusing only on a few
        pixels that are very wrong. Laplace distribution corresponds to
        L1 and partially solves the issue of MSE.

        Parameters
        ----------
        p: torch.distributions
            probabilistic decoder (or likelihood of generating true data
            sample given the latent code).
        data: torch.Tensor
            reference data.

        Returns
        -------
        loss: torch.Tensor
            scalar per image cross entropy (i.e. normalized per batch but
            not per pixel and channel).

        Raises
        ------
        ValueError
            if 'p' is not a Bernoulli, Normal or Laplace distribution.
        """
        if isinstance(p, Bernoulli):
            loss = func.binary_cross_entropy(p.probs, data, reduction="sum")
        elif isinstance(p, Normal):
            if self.use_mse:
                loss = func.mse_loss(p.loc, data, reduction="sum")
            else:
                # 'compute_ll' only reduces the last dimension: sum the rest
                # so that every branch yields a scalar over the whole batch.
                loss = self.compute_ll(p, data).sum()
        elif isinstance(p, Laplace):
            loss = func.l1_loss(p.loc, data, reduction="sum")
            # empirical value to give similar values than bernoulli => use
            # same hyperparam
            loss = loss * 3
            loss = loss * (loss != 0)  # masking to avoid nan
        else:
            raise ValueError(
                "Unknown distribution: {}".format(type(p).__name__))
        batch_size = len(data)
        loss = loss / batch_size
        # Count the samples seen so far (read by 'update_train_step') and
        # keep a detached copy of the loss for monitoring.
        self.cache["iteration"] = self.cache.get("iteration", 0) + batch_size
        self.cache.setdefault("ll", []).append(loss.detach().cpu().numpy())
        return loss

    def compute_ll(self, p, data):
        """ Compute the negative log likelihood.

        Parameters
        ----------
        p: torch.distributions
            probabilistic decoder (or likelihood of generating true data
            sample given the latent code).
        data: torch.Tensor
            reference data.

        Returns
        -------
        nll: torch.Tensor
            minus the log probability of 'data', summed over the last
            dimension only (the dimension is kept).
        """
        return - p.log_prob(data).sum(-1, keepdim=True)

    def kl_normal_loss(self, q):
        """ Calculates the KL divergence between a normal distribution
        with diagonal covariance and a unit normal distribution.

        The per-dimension KL averaged over the batch (dim 0) is stored in
        the 'kl' cache entry; their sum is returned.

        Parameters
        ----------
        q: torch.distributions
            probabilistic encoder (or estimated posterior probability
            function).
        """
        dimension_wise_kl = kl_divergence(q, Normal(0, 1)).mean(dim=0)
        self.cache.setdefault("kl", []).append(
            dimension_wise_kl.detach().cpu().numpy())
        return dimension_wise_kl.sum()

    @staticmethod
    def kl_log_uniform(normal):
        """ Calculates the KL log uniform divergence.

        Paragraph 4.2 from:
        Variational Dropout Sparsifies Deep Neural Networks
        Molchanov, Dmitry; Ashukha, Arsenii; Vetrov, Dmitry
        https://arxiv.org/abs/1701.05369
        https://github.com/senya-ashukha/variational-dropout-sparsifies-dnn/
        blob/master/KL%20approximation.ipynb
        """
        mu = normal.loc
        logvar = normal.scale.pow(2).log()
        log_alpha = BaseLoss.compute_log_alpha(mu, logvar)
        # Constants of the approximation given in the reference above.
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        neg_kl = (k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 *
                  torch.log1p(torch.exp(-log_alpha)) - k1)
        return - neg_kl.mean(dim=0).sum()

    @staticmethod
    def compute_log_alpha(mu, logvar):
        """ Compute log(alpha) = log(sigma^2 / mu^2), clamped to [-8, 8].

        A small epsilon avoids log(0) when 'mu' is zero.
        """
        return (logvar - 2 * torch.log(torch.abs(mu) + 1e-8)).clamp(
            min=-8, max=8)

    def linear_annealing(self, init, fin):
        """ Linear annealing of a parameter.

        Parameters
        ----------
        init: float
            starting value.
        fin: float
            final value, must be greater than or equal to 'init'.

        Returns
        -------
        annealed: float
            loss factor to gradually add the regularisation.

        Raises
        ------
        ValueError
            if annealing is enabled and 'fin' is lower than 'init'.
        """
        if self.steps_anneal == 0:
            return fin
        # Explicit check rather than 'assert' (stripped under 'python -O').
        if fin < init:
            raise ValueError(
                "Annealing expects fin >= init, got init={} and "
                "fin={}.".format(init, fin))
        delta = fin - init
        annealed = min(init + delta * self.n_train_steps /
                       self.steps_anneal, fin)
        return annealed

    def update_train_step(self, iteration=None):
        """ Update the train step.

        Parameters
        ----------
        iteration: int, default None
            the current iteration. If not set, the number of samples seen
            by 'reconstruction_loss' is used.

        Raises
        ------
        ValueError
            if no iteration is given and none has been cached.
        """
        if iteration is None:
            iteration = self.cache.get("iteration")
        if iteration is None:
            raise ValueError("No iteration specified.")
        self.n_train_steps = iteration
@Losses.register
class BetaHLoss(BaseLoss):
    """ Compute the Beta-VAE loss.

    beta-VAE: Learning Basic Visual Concepts with a Constrained Variational
    Framework, Irina Higgins, ICLR 2017.

    The returned loss is the reconstruction term plus the KL term scaled by
    'beta' and, in training mode, by a linear annealing factor.
    """
    def __init__(self, beta=4, **kwargs):
        """ Init class.

        Parameters
        ----------
        beta: float, default 4
            weight of the kl divergence.
        kwargs: dict
            additional arguments for 'BaseLoss'.
        """
        super(BetaHLoss, self).__init__(**kwargs)
        self.beta = beta

    def __call__(self, p, data, **kwargs):
        """ Compute the loss.

        Parameters
        ----------
        p: torch.distributions
            probabilistic decoder.
        data: torch.Tensor
            reference data.
        kwargs: dict
            ignored.

        Returns
        -------
        loss: torch.Tensor
            the total loss.
        extra: dict
            the 'rec_loss' and (weighted) 'kl_loss' terms.
        """
        q, _, model = self.get_params()
        rec_loss = self.reconstruction_loss(p, data)
        kl_loss = self.kl_normal_loss(q)
        # No annealing outside of training.
        anneal_reg = 1
        if model.training:
            anneal_reg = self.linear_annealing(init=0, fin=1)
            self.update_train_step()
        kl_loss = anneal_reg * (self.beta * kl_loss)
        return rec_loss + kl_loss, {"rec_loss": rec_loss, "kl_loss": kl_loss}
@Losses.register
class BetaBLoss(BaseLoss):
    """ Compute the Beta-VAE loss.

    Understanding disentangling in beta-VAE, Burgess, arXiv 2018.

    The KL term is replaced by 'gamma * |KL - C|' where the capacity C is
    linearly annealed from 'C_init' to 'C_fin' during training.
    """
    def __init__(self, C_init=0., C_fin=20., gamma=100., **kwargs):
        """ Init class.

        Parameters
        ----------
        C_init: float, default 0
            starting annealed capacity C.
        C_fin: float, default 20
            final annealed capacity C.
        gamma: float, default 100
            weight of the KL divergence term.
        kwargs: dict
            additional arguments for 'BaseLoss'.
        """
        super(BetaBLoss, self).__init__(**kwargs)
        self.gamma = gamma
        self.C_init = C_init
        self.C_fin = C_fin

    def __call__(self, p, data, **kwargs):
        """ Compute the loss.

        Parameters
        ----------
        p: torch.distributions
            probabilistic decoder.
        data: torch.Tensor
            reference data.
        kwargs: dict
            ignored.

        Returns
        -------
        loss: torch.Tensor
            the total loss.
        extra: dict
            the 'rec_loss' and capacity-constrained 'kl_loss' terms.
        """
        q, _, model = self.get_params()
        rec_loss = self.reconstruction_loss(p, data)
        kl_loss = self.kl_normal_loss(q)
        if not model.training:
            # Evaluation always uses the final capacity.
            capacity = self.C_fin
        else:
            capacity = self.linear_annealing(init=self.C_init,
                                             fin=self.C_fin)
            self.update_train_step()
        kl_loss = self.gamma * (kl_loss - capacity).abs()
        return rec_loss + kl_loss, {"rec_loss": rec_loss, "kl_loss": kl_loss}
@Losses.register
class SparseLoss(BaseLoss):
    """ Compute the Beta-Sparse VAE loss.

    Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of
    Heterogeneous Data, Luigi Antelmi, Nicholas Ayache, Philippe Robert,
    Marco Lorenzi, PMLR 2019.

    The KL term is the log-uniform approximation of 'BaseLoss.kl_log_uniform',
    scaled by 'beta' and, in training mode, by a linear annealing factor.
    """
    def __init__(self, beta=4, **kwargs):
        """ Init class.

        Parameters
        ----------
        beta: float, default 4
            weight of the kl divergence.
        kwargs: dict
            additional arguments for 'BaseLoss'.
        """
        super(SparseLoss, self).__init__(**kwargs)
        self.beta = beta

    def __call__(self, p, data, **kwargs):
        """ Compute the loss.

        Parameters
        ----------
        p: torch.distributions
            probabilistic decoder.
        data: torch.Tensor
            reference data.
        kwargs: dict
            ignored.

        Returns
        -------
        loss: torch.Tensor
            the total loss.
        extra: dict
            the 'rec_loss' and (weighted) 'kl_loss' terms.
        """
        q, z, model = self.get_params()
        rec_loss = self.reconstruction_loss(p, data)
        kl_loss = BaseLoss.kl_log_uniform(q)
        if model.training:
            anneal_reg = self.linear_annealing(init=0, fin=1)
            # Advance the step counter like the other losses do; without it
            # 'n_train_steps' stays at 0 and the annealing never progresses.
            self.update_train_step()
        else:
            anneal_reg = 1
        kl_loss = anneal_reg * (self.beta * kl_loss)
        loss = rec_loss + kl_loss
        return loss, {"rec_loss": rec_loss, "kl_loss": kl_loss}
@Losses.register
class FactorKLoss(BaseLoss):
    """ Compute the Factor-VAE loss (algorithm 2).

    Disentangling by factorising, Hyunjik, arXiv 2018.

    Only the set-up is implemented: calling the loss raises
    'NotImplementedError'.
    """
    def __init__(self, device, gamma=10., disc_kwargs=None,
                 optim_kwargs=None, optimizer=None, **kwargs):
        """ Init class.

        Parameters
        ----------
        device: torch.device
            the device.
        gamma: float, default 10
            Weight of the TC loss term. `gamma` in the paper.
        disc_kwargs: dict, default None
            discriminator arguments (empty if not set).
        optim_kwargs: dict, default None
            Adam optimizer arguments for the discriminator
            (lr=5e-5, betas=(0.5, 0.9) if not set).
        optimizer: torch.optim, default None
            the network optimizer.
        kwargs: dict
            additional arguments for 'BaseLoss'.
        """
        super(FactorKLoss, self).__init__(**kwargs)
        # Build fresh dicts per instance instead of sharing mutable defaults.
        if disc_kwargs is None:
            disc_kwargs = {}
        if optim_kwargs is None:
            optim_kwargs = dict(lr=5e-5, betas=(0.5, 0.9))
        self.gamma = gamma
        self.device = device
        self.optimizer = optimizer
        # NOTE(review): 'Discriminator' does not appear to be defined or
        # imported in this module; it must be provided for this class to
        # be instantiable.
        self.discriminator = Discriminator(**disc_kwargs).to(device)
        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(),
                                            **optim_kwargs)

    def __call__(self, *args, **kwargs):
        """ Compute the loss.

        Raises
        ------
        NotImplementedError
            always: the Factor-VAE loss is not implemented yet.
        """
        raise NotImplementedError
@Losses.register
class BtcvaeLoss(BaseLoss):
    """ Compute the decomposed KL loss with either minibatch weighted
    sampling or minibatch stratified sampling.

    Isolating sources of disentanglement in variational autoencoders, Tian Qi,
    Advances in Neural Information Processing Systems, 2018.
    """
    def __init__(self, dataset_size, alpha=1., beta=6., gamma=1.,
                 is_mss=True, **kwargs):
        """ Init class.

        Parameters
        ----------
        dataset_size: int
            number of training images in the dataset.
        alpha: float, default 1
            weight of the mutual information term.
        beta: float, default 6
            weight of the total correlation term.
        gamma: float, default 1
            weight of the dimension-wise KL term.
        is_mss: bool, default True
            whether to use minibatch stratified sampling instead of minibatch
            weighted sampling.
        kwargs: dict
            additional arguments for 'BaseLoss'.
        """
        super(BtcvaeLoss, self).__init__(**kwargs)
        self.dataset_size = dataset_size
        self.beta = beta
        self.alpha = alpha
        self.gamma = gamma
        self.is_mss = is_mss

    def __call__(self, p, data, **kwargs):
        """ Compute the loss.

        Parameters
        ----------
        p: torch.distributions
            probabilistic decoder.
        data: torch.Tensor
            reference data.
        kwargs: dict
            ignored; accepted for consistency with the other losses.

        Returns
        -------
        loss: torch.Tensor
            reconstruction + weighted MI, TC and dimension-wise KL terms.
        extra: dict
            the individual loss terms ('rec_loss' and the weighted
            'mi_loss', 'tc_loss' and 'dw_kl_loss').
        """
        q, z, model = self.get_params()
        rec_loss = self.reconstruction_loss(p, data)
        log_pz, log_qz, log_prod_qzi, log_qz_x = self.get_probs(z, q)
        # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
        mi_loss = (log_qz_x - log_qz).mean()
        # TC[z] = KL[q(z)||prod_i q(z_i)]
        tc_loss = (log_qz - log_prod_qzi).mean()
        # dw_kl_loss is KL[q(z)||p(z)] instead of usual KL[q(z|x)||p(z))]
        dw_kl_loss = (log_prod_qzi - log_pz).mean()
        if model.training:
            anneal_reg = self.linear_annealing(init=0, fin=1)
            self.update_train_step()
        else:
            anneal_reg = 1
        mi_loss = self.alpha * mi_loss
        tc_loss = self.beta * tc_loss
        # Only the dimension-wise KL term is annealed.
        dw_kl_loss = anneal_reg * self.gamma * dw_kl_loss
        loss = rec_loss + mi_loss + tc_loss + dw_kl_loss
        return loss, {"rec_loss": rec_loss, "mi_loss": mi_loss,
                      "tc_loss": tc_loss, "dw_kl_loss": dw_kl_loss}

    def get_probs(self, z, q):
        """ Estimate the log densities used by the KL decomposition.

        Parameters
        ----------
        z: torch.Tensor (batch_size, dim)
            latent samples.
        q: torch.distributions
            probabilistic encoder (or estimated posterior probability
            function).

        Returns
        -------
        log_pz, log_qz, log_prod_qzi, log_qz_x: torch.Tensor (batch_size, )
            log p(z), log q(z), log prod_i q(z_i) and log q(z|x).
        """
        # Calculate log q(z|x)
        log_qz_x = q.log_prob(z).sum(dim=1)
        # Calculate log p(z) under a standard normal prior
        pz = Normal(loc=torch.zeros_like(z), scale=1)
        log_pz = pz.log_prob(z).sum(dim=1)
        # Calculate log q(z) from all sample/posterior pairs of the batch
        batch_size = len(z)
        mat_log_qz = BtcvaeLoss.matrix_log_density_gaussian(z, q)
        if self.is_mss:
            log_iw_mat = BtcvaeLoss.log_importance_weight_matrix(
                batch_size, self.dataset_size)
            log_iw_mat = torch.unsqueeze(log_iw_mat, dim=-1).to(z.device)
            mat_log_qz = mat_log_qz + log_iw_mat
        log_qz = torch.logsumexp(mat_log_qz.sum(dim=2), dim=1, keepdim=False)
        log_prod_qzi = torch.logsumexp(
            mat_log_qz, dim=1, keepdim=False).sum(dim=1)
        return log_pz, log_qz, log_prod_qzi, log_qz_x

    @staticmethod
    def matrix_log_density_gaussian(x, q):
        """ Calculates log density of a Gaussian for all combination of batch
        pairs of 'x' and 'mu', i.e. return tensor of shape (batch_size,
        batch_size, dim) instead of (batch_size, dim) in the usual log density.

        Parameters
        ----------
        x: torch.Tensor (batch_size, dim)
            value at which to compute the density.
        q: torch.distributions
            probabilistic encoder (or estimated posterior probability
            function).
        """
        # Broadcast samples along dim 1 against posteriors along dim 0.
        x = torch.unsqueeze(x, dim=1)
        _mu = torch.unsqueeze(q.loc, dim=0)
        _sigma = torch.unsqueeze(q.scale, dim=0)
        _q = Normal(loc=_mu, scale=_sigma)
        return _q.log_prob(x)

    @staticmethod
    def log_importance_weight_matrix(batch_size, dataset_size):
        """ Calculates a log importance weight matrix.

        Parameters
        ----------
        batch_size: int
            number of training images in the batch (at least 2).
        dataset_size: int
            number of training images in the dataset.

        Raises
        ------
        ValueError
            if 'batch_size' is lower than 2 (the weights divide by
            batch_size - 1).
        """
        if batch_size < 2:
            raise ValueError(
                "Minibatch stratified sampling requires a batch size of at "
                "least 2, got {}.".format(batch_size))
        N = dataset_size
        M = batch_size - 1
        strat_weight = (N - M) / (N * M)
        W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
        # NOTE(review): in the row-major flattened view, a stride of
        # M + 1 == batch_size selects column 0 (offset 0) and column 1
        # (offset 1), not the diagonal. This mirrors the reference
        # implementation; confirm it is the intended weighting.
        W.view(-1)[::M + 1] = 1 / N
        W.view(-1)[1::M + 1] = strat_weight
        W[M - 1, 0] = strat_weight
        return W.log()
@Losses.register
class VAEGMPLoss(object):
    """ VAEGMP Loss.

    Negative ELBO estimated from single samples: reconstruction negative
    log-likelihood plus log q(z|x) - log p(z), both averaged over the batch.
    """
    def __init__(self):
        """ Init class.
        """
        super(VAEGMPLoss, self).__init__()
        self.layer_outputs = None

    def __call__(self, p_x_given_z, data):
        """ Compute loss.

        Parameters
        ----------
        p_x_given_z: torch.distributions
            the decoder distribution.
        data: torch.Tensor
            reference data.

        Returns
        -------
        loss: torch.Tensor
            'nll + kl_div_z'.
        extra: dict
            the 'nll' and 'kl_div_z' terms.

        Raises
        ------
        ValueError
            if no intermediate layer outputs have been registered.
        """
        outputs = self.layer_outputs
        if outputs is None:
            raise ValueError(
                "This loss needs intermediate layers outputs. Please register "
                "an appropriate callback.")
        q_z_given_x = outputs["q_z_given_x"]
        z = outputs["z"]
        p_z = outputs["p_z"]
        n_samples = len(data)
        # Reconstruction loss term, i.e. the negative log-likelihood.
        nll = -p_x_given_z.log_prob(data).sum() / n_samples
        # Single-sample estimate of the KL between posterior and prior on z.
        kl_div_z = (q_z_given_x.log_prob(z) - p_z.log_prob(z)).sum() / n_samples
        # Minimizing this maximizes the ELBO.
        return nll + kl_div_z, {"nll": nll, "kl_div_z": kl_div_z}
@Losses.register
class GMVAELoss(object):
    """ GMVAE Loss.

    Sum of the reconstruction negative log-likelihood, a single-sample KL
    estimate on z and an entropy term on q(y|x), all averaged over the
    batch. Optionally reports the clustering accuracy.
    """
    def __init__(self):
        """ Init class.
        """
        super(GMVAELoss, self).__init__()
        self.layer_outputs = None

    def __call__(self, p_x_given_z, data, labels=None):
        """ Compute loss.

        Parameters
        ----------
        p_x_given_z: torch.distributions
            the decoder distribution.
        data: torch.Tensor
            reference data.
        labels: torch.Tensor or array, default None
            ground truth cluster labels, only used to report the
            clustering accuracy.

        Returns
        -------
        loss: torch.Tensor
            'nll + kl_div_z + nent'.
        extra: dict
            the 'nll', 'kl_div_z', 'nent' terms and 'cluster_acc' (0 when no
            labels are given).

        Raises
        ------
        ValueError
            if no intermediate layer outputs have been registered.
        """
        if self.layer_outputs is None:
            raise ValueError(
                "This loss needs intermediate layers outputs. Please register "
                "an appropriate callback.")
        q_y_given_x = self.layer_outputs["q_y_given_x"]
        q_z_given_xy = self.layer_outputs["q_z_given_xy"]
        z = self.layer_outputs["z"]
        p_z_given_y = self.layer_outputs["p_z_given_y"]
        # Reconstruction loss term i.e. the negative log-likelihood
        nll = - p_x_given_z.log_prob(data).sum()
        nll /= len(data)
        # Latent loss between approximate posterior and prior for z
        kl_div_z = (q_z_given_xy.log_prob(z) - p_z_given_y.log_prob(z)).sum()
        kl_div_z /= len(data)
        # Conditional entropy of q(y|x). Work in log space: softmax followed
        # by log gives 0 * log(0) = nan as soon as a probability underflows.
        # NOTE(review): this adds the (positive) entropy to the loss, so
        # minimizing the loss lowers the entropy; the name 'nent' suggests a
        # negative entropy may have been intended. Confirm the sign.
        log_probs = func.log_softmax(q_y_given_x.logits, dim=-1)
        nent = - (log_probs.exp() * log_probs).sum()
        nent /= len(data)
        # Need to maximise the ELBO with respect to these weights
        loss = nll + kl_div_z + nent
        # Keep track of the clustering accuracy during training
        if labels is not None:
            cluster_acc = GMVAELoss.cluster_acc(
                q_y_given_x.logits, labels, is_logits=True)
        else:
            cluster_acc = 0
        return loss, {"nll": nll, "kl_div_z": kl_div_z, "nent": nent,
                      "cluster_acc": cluster_acc}

    @staticmethod
    def cluster_acc(y_pred, y, is_logits=False):
        """ Best-match clustering accuracy.

        Predicted clusters are matched to labels with the Hungarian
        algorithm ('linear_sum_assignment') so that the matching count is
        maximal.

        Parameters
        ----------
        y_pred: torch.Tensor or np.ndarray
            integer cluster predictions, or per-class scores if 'is_logits'.
        y: torch.Tensor or np.ndarray
            integer ground truth labels.
        is_logits: bool, default False
            if set, take the argmax of 'y_pred' along axis 1.

        Returns
        -------
        acc: float
            fraction of samples whose matched cluster equals their label.
        """
        if isinstance(y_pred, torch.Tensor):
            y_pred = y_pred.detach().cpu().numpy()
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        if is_logits:
            y_pred = np.argmax(y_pred, axis=1)
        n_classes = max(y_pred.max(), y.max()) + 1
        # Contingency table: gain[i, j] counts samples predicted i with
        # label j ('np.add.at' accumulates repeated index pairs).
        gain = np.zeros((n_classes, n_classes), dtype=np.int64)
        np.add.at(gain, (y_pred, y), 1)
        # Turn the maximization into the minimization solved by scipy.
        cost = gain.max() - gain
        row_ind, col_ind = linear_sum_assignment(cost)
        return gain[row_ind, col_ind].sum() / y_pred.size
@Losses.register
class MCVAELoss(object):
    """ MCVAE loss.

    Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of
    Heterogeneous Data, Luigi Antelmi, Nicholas Ayache, Philippe Robert,
    Marco Lorenzi Proceedings of the 36th International Conference on Machine
    Learning, PMLR 97:302-311, 2019.

    MCVAE combines two terms:

    1. KL divergence: how far the latent distribution of each encoded
       channel is from the prior. With a standard Gaussian prior and a
       Gaussian posterior with diagonal covariance the KL-divergence is
       analytic (the log-uniform approximation is used when 'sparse' is set).
    2. log-likelihood LL of every decoded channel given every encoded
       channel.

    loss = beta * KL - LL, each term averaged over the selected channels.
    """
    def __init__(self, n_channels, beta=1., enc_channels=None,
                 dec_channels=None, sparse=False, nodecoding=False):
        """ Init class.

        Parameters
        ----------
        n_channels: int
            the number of channels.
        beta, float, default 1.
            for beta-VAE.
        enc_channels: list of int, default None
            encode only these channels (for kl computation); all channels if
            not set.
        dec_channels: list of int, default None
            decode only these channels (for ll computation); all channels if
            not set.
        sparse: bool, default False
            use sparsity constraint.
        nodecoding: bool, default False
            if set do not apply the decoding.
        """
        super(MCVAELoss, self).__init__()
        self.n_channels = n_channels
        self.beta = beta
        self.sparse = sparse
        if enc_channels is not None:
            assert(len(enc_channels) <= n_channels)
        if dec_channels is not None:
            assert(len(dec_channels) <= n_channels)
        all_channels = list(range(n_channels))
        self.enc_channels = (
            all_channels if enc_channels is None else enc_channels)
        self.dec_channels = (
            list(range(n_channels)) if dec_channels is None else dec_channels)
        self.n_enc_channels = len(self.enc_channels)
        self.n_dec_channels = len(self.dec_channels)
        self.nodecoding = nodecoding
        self.layer_outputs = None

    def __call__(self, p):
        """ Compute loss.

        The input data 'x' and the latent distributions 'q' are read from
        the registered 'layer_outputs'.

        Parameters
        ----------
        p: list of list of distributions
            p[i][j] is the reconstruction of channel i decoded from the
            latent code of channel j.

        Returns
        -------
        total: torch.Tensor
            'kl - ll', or -1 when 'nodecoding' is set.
        extra: dict
            the 'kl' and 'll' terms (not returned when 'nodecoding' is set).

        Raises
        ------
        ValueError
            if no intermediate layer outputs have been registered.
        """
        if self.nodecoding:
            return -1
        outputs = self.layer_outputs
        if outputs is None:
            raise ValueError(
                "This loss needs intermediate layers outputs. Please register "
                "an appropriate callback.")
        x = outputs["x"]
        q = outputs["q"]
        kl = self.compute_kl(q, self.beta)
        ll = self.compute_ll(p, x)
        return kl - ll, {"kl": kl, "ll": ll}

    def compute_kl(self, q, beta):
        """ Beta-weighted KL term averaged over the encoded channels.

        Each channel's divergence is summed over the last dimension and
        averaged over the first (batch) dimension.
        """
        kl = 0
        for c_idx, qi in enumerate(q):
            if c_idx not in self.enc_channels:
                continue
            if self.sparse:
                divergence = self._kl_log_uniform(qi)
            else:
                divergence = kl_divergence(qi, Normal(0, 1))
            kl += divergence.sum(-1, keepdim=True).mean(0)
        return beta * kl / self.n_enc_channels

    def compute_ll(self, p, x):
        """ Log-likelihood averaged over (decoded, encoded) channel pairs.

        p[i][j] models channel i given the latent code of channel j.
        """
        ll = 0
        for dec_idx in range(self.n_channels):
            if dec_idx not in self.dec_channels:
                continue
            for enc_idx in range(self.n_channels):
                if enc_idx not in self.enc_channels:
                    continue
                ll += self._compute_ll(
                    p=p[dec_idx][enc_idx], x=x[dec_idx]).mean(0)
        return ll / self.n_enc_channels / self.n_dec_channels

    def compute_log_alpha(self, mu, logvar):
        """ log(alpha) = log(sigma^2 / mu^2), clamped to [-8, 8].
        """
        # clamp because dropout rate p in 0-99%, where p = alpha/(alpha+1)
        log_mu_sq = 2 * torch.log(torch.abs(mu) + 1e-8)
        return (logvar - log_mu_sq).clamp(min=-8, max=8)

    def _compute_ll(self, p, x):
        """ Per-sample log-likelihood, summed over all non-batch values.
        """
        per_sample = p.log_prob(x).view(len(x), -1)
        return per_sample.sum(-1, keepdim=True)

    def _kl_log_uniform(self, normal):
        """ Element-wise KL log uniform divergence approximation.

        Paragraph 4.2 from:
        Variational Dropout Sparsifies Deep Neural Networks
        Molchanov, Dmitry; Ashukha, Arsenii; Vetrov, Dmitry
        https://arxiv.org/abs/1701.05369
        https://github.com/senya-ashukha/variational-dropout-sparsifies-dnn/
        blob/master/KL%20approximation.ipynb
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        log_alpha = self.compute_log_alpha(
            normal.loc, normal.scale.pow(2).log())
        sigmoid_term = k1 * torch.sigmoid(k2 + k3 * log_alpha)
        softplus_term = 0.5 * torch.log1p(torch.exp(-log_alpha))
        return -(sigmoid_term - softplus_term - k1)
@Losses.register
class VaDELoss(object):
    """ VaDE loss.

    Reconstruction term (binary cross-entropy or MSE depending on
    'model.binary') weighted by 'alpha', plus the mixture-prior terms
    computed from the model parameters 'u_p', 'lambda_p' and 'theta_p'.
    The result is averaged over the batch.
    """
    def __init__(self, alpha=1):
        """ Init class.

        Parameters
        ----------
        alpha: float, default 1
            reconstruction loss weight.
        """
        super(VaDELoss, self).__init__()
        self.layer_outputs = None
        self.alpha = alpha

    def __call__(self, x_pred, x, *args, **kwargs):
        """ Compute the loss.

        Parameters
        ----------
        x_pred: torch.Tensor
            the reconstructed data (fed to 'binary_cross_entropy' when
            'model.binary' is set, so expected in [0, 1] in that case).
        x: torch.Tensor
            the reference data.
        args, kwargs:
            ignored.

        Returns
        -------
        loss: torch.Tensor
            the batch-averaged loss (a single tensor, not a tuple).

        Raises
        ------
        ValueError
            if no intermediate layer outputs have been registered.
        """
        if self.layer_outputs is None:
            raise ValueError("The model needs to return the latent space "
                             "distribution parameters z, z_mu, z_var.")
        z_mu = self.layer_outputs["z_mu"]
        z_logvar = self.layer_outputs["z_logvar"]
        z = self.layer_outputs["z"]
        model = self.layer_outputs["model"]
        n_classes = model.n_classes
        # Presumably the per-sample cluster responsibilities q(c|x),
        # computed by the model; broadcast against (batch, n_classes) below.
        gamma = model.get_gamma(z, z_mu, z_logvar)
        # Expand the latent statistics to (batch, latent_dim, n_classes).
        # Assumes u_p / lambda_p are (latent_dim, n_classes) and theta_p is
        # (n_classes, ). TODO confirm against the model.
        z_mu_t = z_mu.unsqueeze(dim=2).expand(
            z_mu.size()[0], z_mu.size()[1], n_classes)
        z_logvar_t = z_logvar.unsqueeze(dim=2).expand(
            z_logvar.size()[0], z_logvar.size()[1], n_classes)
        u_t = model.u_p.unsqueeze(dim=0).expand(
            z.size()[0], model.u_p.size()[0], model.u_p.size()[1])
        lambda_t = model.lambda_p.unsqueeze(dim=0).expand(
            z.size()[0], model.lambda_p.size()[0], model.lambda_p.size()[1])
        theta_t = model.theta_p.unsqueeze(dim=0).expand(
            z.size()[0], n_classes)
        if model.binary:
            rec_loss = torch.sum(
                func.binary_cross_entropy(x_pred, x, reduction="none"), dim=1)
        else:
            rec_loss = torch.sum(
                func.mse_loss(x_pred, x, reduction="none"), dim=1)
        # Apply the documented reconstruction weight (it was previously
        # stored but ignored; the default of 1 leaves results unchanged).
        rec_loss = self.alpha * rec_loss
        log_p_z_c = (
            torch.sum(
                0.5 * gamma
                * torch.sum((
                    math.log(2 * math.pi) + torch.log(lambda_t)
                    + torch.exp(z_logvar_t) / lambda_t
                    + (z_mu_t - u_t)**2 / lambda_t), dim=1), dim=1))
        q_entropy = (
            - 0.5 * torch.sum(1 + z_logvar + math.log(2 * math.pi), dim=1))
        log_p_c = - torch.sum(gamma * torch.log(theta_t), dim=1)
        # NOTE(review): gamma * log(gamma) is nan where gamma == 0 unless
        # 'get_gamma' already guards against zeros. Confirm.
        log_q_c_x = torch.sum(gamma * torch.log(gamma), dim=1)
        # Per-sample terms are averaged over the batch.
        loss = torch.mean(rec_loss + log_p_z_c + q_entropy + log_p_c +
                          log_q_c_x)
        return loss
Follow us
© 2019, pynet developers .
Inspired by AZMIND template.