Helper Module for Deep Learning.
Source code for pynet.models.vae.mcvae
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of
Heterogeneous Data.
"""
# Imports
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.distributions import Normal, kl_divergence
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks, Losses
from .vae import VAENet
# Global parameters
logger = logging.getLogger("pynet")
@Networks.register
@DeepLearningDecorator(family=("encoder", "vae"))
class MCVAE(nn.Module):
    """ Sparse Multi-Channel Variational Autoencoder (sMCVAE).

    Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of
    Heterogeneous Data, Luigi Antelmi, Nicholas Ayache, Philippe Robert,
    Marco Lorenzi Proceedings of the 36th International Conference on Machine
    Learning, PMLR 97:302-311, 2019.
    """

    def __init__(self, latent_dim, n_channels, n_feats,
                 noise_init_logvar=-3, noise_fixed=False, sparse=False,
                 vae_model="dense", vae_kwargs=None, nodecoding=False):
        """ Init class.

        Parameters
        ----------
        latent_dim: int
            the number of latent dimensions.
        n_channels: int
            the number of channels.
        n_feats: list of int
            each channel input dimensions.
        noise_init_logvar: float, default -3
            default noise parameters values.
        noise_fixed: bool, default False
            if set, do not require gradients on the noise parameters.
        sparse: bool, default False
            use sparsity constraint.
        vae_model: str, default "dense"
            the VAE network used to encode each channel.
        vae_kwargs: dict, default None
            extra parameters passed to the initialization of the VAE model.
        nodecoding: bool, default False
            if set do not apply the decoding.

        Raises
        ------
        ValueError
            if 'vae_model' is not a supported model name.
        """
        super(MCVAE, self).__init__()
        assert n_channels == len(n_feats)
        self.latent_dim = latent_dim
        self.n_channels = n_channels
        self.n_feats = n_feats
        self.sparse = sparse
        self.noise_init_logvar = noise_init_logvar
        self.noise_fixed = noise_fixed
        if vae_model == "dense":
            self.vae_class = VAENet
        else:
            raise ValueError("Unknown VAE model.")
        # Work on a shallow copy: 'init_vae' inserts default keys and must
        # not modify the dictionary owned by the caller.
        self.vae_kwargs = dict(vae_kwargs) if vae_kwargs else {}
        self.nodecoding = nodecoding
        self.init_vae()

    def init_vae(self):
        """ Create one VAE model per channel.

        When the sparse mode is active, a single 'log_alpha' parameter of
        shape (1, latent_dim) is created and the same object is passed to
        every channel VAE. Otherwise 'log_alpha' is None.
        """
        if self.sparse:
            self.log_alpha = nn.Parameter(
                torch.FloatTensor(1, self.latent_dim).normal_(0, 0.01))
        else:
            self.log_alpha = None
        # These two options always need a value: set the defaults once
        # rather than at each loop iteration.
        self.vae_kwargs.setdefault("conv_flts", None)
        self.vae_kwargs.setdefault("dense_hidden_dims", None)
        channels = [
            self.vae_class(
                input_channels=1,
                input_dim=n_feat,
                latent_dim=self.latent_dim,
                noise_out_logvar=self.noise_init_logvar,
                noise_fixed=self.noise_fixed,
                sparse=self.sparse,
                act_func=torch.nn.Tanh,
                final_activation=False,
                log_alpha=self.log_alpha,
                **self.vae_kwargs)
            for n_feat in self.n_feats[:self.n_channels]]
        self.vae = torch.nn.ModuleList(channels)

    def encode(self, x):
        """ Encodes the input by passing through the encoder network
        and returns the latent distribution for each channel.

        Parameters
        ----------
        x: list of Tensor, (C,) -> (N, Fc)
            input tensors to encode.

        Returns
        -------
        out: list, (C,)
            the result of each channel VAE 'encode' method. The 'forward'
            method calls 'rsample' on each item, i.e. each item is expected
            to be the latent distribution q(z|x_c) of the channel.
        """
        return [self.vae[c_idx].encode(x[c_idx])
                for c_idx in range(self.n_channels)]

    def decode(self, z):
        """ Maps the given latent codes onto the image space.

        Every channel decoder is applied to every channel latent code.

        Parameters
        ----------
        z: list of Tensor, (C,) -> (N, D)
            one latent sample per channel.

        Returns
        -------
        p: list of list, (C, C)
            p[i][j] is the output of the decoder of channel i applied to
            the latent code of channel j, i.e. the prediction p(x_i|z_j).
        """
        return [
            [self.vae[c_idx1].decode(z[c_idx2])
             for c_idx2 in range(self.n_channels)]
            for c_idx1 in range(self.n_channels)]

    def reconstruct(self, p):
        """ Average, for each channel, the decoded locations obtained from
        all the latent codes.

        Parameters
        ----------
        p: list of list, (C, C)
            the decoded distributions as returned by 'decode': each item
            must expose a 'loc' tensor.

        Returns
        -------
        x_hat: list of array, (C,)
            for each channel, the mean over the C latent codes of the
            detached 'loc' tensors.
        """
        x_hat = []
        for c_idx1 in range(self.n_channels):
            locs = torch.stack([
                p[c_idx1][c_idx2].loc.detach()
                for c_idx2 in range(self.n_channels)])
            x_hat.append(locs.mean(dim=0).cpu().numpy())
        return x_hat

    def forward(self, x):
        """ Encode each channel, sample the latent codes and optionally
        decode them.

        Parameters
        ----------
        x: list of Tensor, (C,)
            the input tensors, one per channel.

        Returns
        -------
        out: list
            the sampled latent codes if 'nodecoding' is set, the result of
            'decode' otherwise.
        extra: dict
            the encoded distributions ('q') and the input data ('x').
        """
        qs = self.encode(x)
        z = [q.rsample() for q in qs]
        extra = {"q": qs, "x": x}
        if self.nodecoding:
            return z, extra
        return self.decode(z), extra

    def p_to_prediction(self, p):
        """ Get the prediction from various types of distributions.

        Parameters
        ----------
        p: list, Normal or Bernoulli
            a distribution or a (nested) list of distributions.

        Returns
        -------
        pred: array or list
            the 'loc' of a Normal distribution, the 'probs' of a Bernoulli
            distribution, or a list with the same nesting as the input.

        Raises
        ------
        NotImplementedError
            if a distribution type is not supported.
        """
        if isinstance(p, list):
            return [self.p_to_prediction(_p) for _p in p]
        if isinstance(p, Normal):
            return p.loc.cpu().detach().numpy()
        # 'Bernoulli' is not imported at the module level: access it from
        # the 'torch.distributions' namespace to avoid a NameError.
        if isinstance(p, torch.distributions.Bernoulli):
            return p.probs.cpu().detach().numpy()
        raise NotImplementedError

    def apply_threshold(self, z, threshold, keep_dims=True, reorder=False):
        """ Apply dropout threshold.

        The input tensors are left untouched: the thresholding is applied
        on copies.

        Parameters
        ----------
        z: list of Tensor, (C,) -> (N, D)
            distribution samples.
        threshold: float
            dropout threshold (at most 1).
        keep_dims: bool default True
            if set, the latent dimensions with a dropout rate not lower
            than the threshold are set to 0, otherwise these dimensions are
            removed.
        reorder: bool default False
            if set, reorder the retained latent dimensions by increasing
            dropout rate.

        Returns
        -------
        z_keep: list of Tensor
            the thresholded samples.
        """
        assert threshold <= 1.0
        # Flatten instead of squeezing: with only one latent dimension (or
        # only one retained dimension) 'squeeze' returns a 0-d tensor and
        # the indexing below would drop the latent axis.
        dropout = self.dropout.reshape(-1)
        keep = dropout < threshold
        # The ordering does not depend on the samples: compute it once.
        if keep_dims:
            order = torch.argsort(dropout)
        else:
            order = torch.argsort(dropout[keep])
        z_keep = []
        for samples in z:
            if keep_dims:
                # Clone to avoid zeroing the caller tensors in place.
                samples = samples.clone()
                samples[:, ~keep] = 0
            else:
                samples = samples[:, keep]
            if reorder:
                samples = samples[:, order]
            z_keep.append(samples)
        return z_keep

    @property
    def dropout(self):
        """ The dropout rates alpha / (alpha + 1) computed from the detached
        'log_alpha' parameter, with alpha = exp(log_alpha).

        Raises
        ------
        NotImplementedError
            if the sparse mode is not active.
        """
        if self.sparse:
            alpha = torch.exp(self.log_alpha.detach())
            return alpha / (alpha + 1)
        else:
            raise NotImplementedError
Follow us
© 2019, pynet developers .
Inspired by AZMIND template.
Inspired by AZMIND template.