Menu

Helper Module for Deep Learning.

Source code for pynet.models.vae.vae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################


"""
Variational Auto-Encoder (VAE).
"""


# Imports
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.distributions import Bernoulli, Normal

from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks, init_weight


# Global parameters
# Module-level logger, fetched under the "pynet" name.
logger = logging.getLogger("pynet")


class Encoder(nn.Module):
    """ The encoder part of a VAE.

    The input goes through optional strided convolutions, is flattened, goes
    through optional dense layers and is finally mapped to the parameters of
    a Normal distribution over the latent space.
    """

    def __init__(self, input_channels, input_dim, conv_flts,
                 dense_hidden_dims, latent_dim, act_func=None, dropout=0,
                 log_alpha=None):
        """ Init class.

        Parameters
        ----------
        input_channels: int
            the number of input channels.
        input_dim: int or list/tuple of int
            the size of input (one value per spatial dimension).
        conv_flts: list of int
            the size of convolutional filters, if None do not include
            convolutional layers.
        dense_hidden_dims: list of int
            the size of dense hidden dimensions, if None do not include
            dense hidden layers.
        latent_dim: int
            the latent dimension.
        act_func: callable, default None
            the activation function class, instantiated without argument
            after each layer. If None, nn.ReLU is used.
        dropout: float, default 0
            define the dropout rate.
        log_alpha: nn.Parameter, default None
            inducing sparse latent representations: when set, the log
            variance is derived from the mean and 'log_alpha' and no
            dedicated log variance layer is created.
        """
        super(Encoder, self).__init__()
        self.act_func = act_func or nn.ReLU
        self.log_alpha = log_alpha

        # torch.Size is a tuple subclass: any list/tuple describes one size
        # per spatial axis, anything else is a single 1d size.
        if isinstance(input_dim, (list, tuple)):
            input_dim = list(input_dim)
        else:
            input_dim = [input_dim]
        ndim = len(input_dim)

        # Convolutional part
        if conv_flts is not None:
            w_conv_layers = Encoder.init_conv_layers(
                input_channels, conv_flts, self.act_func, dropout, ndim)
            self.w_conv = nn.Sequential(*w_conv_layers)
            # Kernels/paddings must mirror the ones of 'init_conv_layers'.
            conv_out_dim = Encoder.final_conv_dim(
                input_dim,
                kernels=[5] * (len(conv_flts) - 1) + [3],
                paddings=[2] * len(conv_flts))[-1]
            flatten_dim = conv_flts[-1] * np.prod(conv_out_dim)
        else:
            self.w_conv = None
            flatten_dim = input_channels * np.prod(input_dim)
        # Hand a plain Python int (not a numpy scalar) to nn.Linear.
        flatten_dim = int(flatten_dim)

        # Dense part
        if dense_hidden_dims is not None:
            w_dense_layers = Encoder.init_dense_layers(
                flatten_dim, dense_hidden_dims, self.act_func, dropout)
            self.w_dense = nn.Sequential(*w_dense_layers)
            final_dim = dense_hidden_dims[-1]
        else:
            self.w_dense = None
            final_dim = flatten_dim

        # Latent distribution heads
        self.w_mu = nn.Linear(final_dim, latent_dim)
        if self.log_alpha is None:
            self.w_logvar = nn.Linear(final_dim, latent_dim)

    @staticmethod
    def final_conv_dim(input_dim, kernels, paddings):
        """ Infer the size of each sample after the convolutions bloc.

        A stride of 2 is assumed for every convolution.

        Parameters
        ----------
        input_dim: list of int
            the input size, one value per spatial dimension.
        kernels: list of int
            the kernel size of each convolution.
        paddings: list of int
            the padding of each convolution.

        Returns
        -------
        all_dims: array of int (n_convs + 1, n_dims)
            the sample size before the first convolution (first row) and
            after each convolution (following rows).
        """
        all_dims = [np.asarray(input_dim)]
        for kernel, padding in zip(kernels, paddings):
            all_dims.append((all_dims[-1] - kernel + 2 * padding) / 2 + 1)
        # Sizes are tracked as floats and truncated once at the end.
        return np.asarray(all_dims).astype(int)

    @staticmethod
    def init_dense_layers(input_dim, hidden_dims, act_func, dropout,
                          final_activation=True):
        """ Create the dense layers.

        Parameters
        ----------
        input_dim: int
            the number of input features.
        hidden_dims: list of int
            the number of output features of each linear layer.
        act_func: callable
            the activation function class.
        dropout: float
            the dropout rate, no dropout layer is added if not positive.
        final_activation: bool, default True
            if False the last linear layer is left bare (no activation).

        Returns
        -------
        layers: list of nn.Module
            the layers, ready to be wrapped in a nn.Sequential.
        """
        layers = []
        current_dim = input_dim
        last = len(hidden_dims) - 1
        for cnt, dim in enumerate(hidden_dims):
            layers.append(nn.Linear(current_dim, dim))
            # A bare final layer gets neither activation nor dropout, so
            # that the produced values are never randomly zeroed.
            if final_activation or cnt != last:
                layers.append(act_func())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
            current_dim = dim
        return layers

    @staticmethod
    def init_conv_layers(input_channels, flts, act_func, dropout, ndim=1):
        """ Create the convolutional layers.

        Each convolution uses a stride of 2 and a padding of 2, with a
        kernel of size 5 except for the last one that uses a kernel of
        size 3.

        Parameters
        ----------
        input_channels: int
            the number of input channels.
        flts: list of int
            the number of output channels of each convolution.
        act_func: callable
            the activation function class.
        dropout: float
            the dropout rate, no dropout layer is added if not positive.
        ndim: int, default 1
            the number of spatial dimensions, selects 'nn.Conv<ndim>d'.

        Returns
        -------
        layers: list of nn.Module
            the layers, ready to be wrapped in a nn.Sequential.
        """
        conv_fn = getattr(nn, "Conv{0}d".format(ndim))
        layers = []
        current_channels = input_channels
        last = len(flts) - 1
        for cnt, n_filts in enumerate(flts):
            layers.append(conv_fn(
                current_channels, out_channels=n_filts,
                kernel_size=(3 if cnt == last else 5), stride=2, padding=2))
            layers.append(act_func())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            current_channels = n_filts
        return layers

    def forward(self, x):
        """ The forward method.

        Parameters
        ----------
        x: torch.Tensor
            the input data, the first axis being the samples.

        Returns
        -------
        q: torch.distributions.Normal
            the distribution over the latent space.
        """
        out = x if self.w_conv is None else self.w_conv(x)
        out = torch.flatten(out, start_dim=1)
        if self.w_dense is not None:
            out = self.w_dense(out)
        z_mu = self.w_mu(out)
        if self.log_alpha is None:
            z_logvar = self.w_logvar(out)
        else:
            z_logvar = Encoder.compute_logvar(z_mu, self.log_alpha)
        # scale = sqrt(exp(logvar))
        return Normal(loc=z_mu, scale=z_logvar.exp().pow(0.5))

    @staticmethod
    def compute_logvar(mu, log_alpha):
        """ Compute the log variance in case of sparsity constraints.

        Returns 'log_alpha + log(mu ** 2)', a small constant being added
        to '|mu|' to keep the logarithm finite.
        """
        return log_alpha + 2 * torch.log(torch.abs(mu) + 1e-8)
class Decoder(nn.Module):
    """ The decoder part of a VAE.

    A latent sample goes through dense layers, then through optional
    transposed convolutions, and is mapped to a Normal distribution over
    the output space.
    """

    def __init__(self, latent_dim, conv_flts, dense_hidden_dims,
                 output_channels, output_dim, noise_out_logvar=-3,
                 noise_fixed=True, act_func=None, final_activation=False,
                 dropout=0):
        """ Init class.

        Parameters
        ----------
        latent_dim: int
            the latent size.
        conv_flts: list of int
            the size of convolutional filters, if None do not include
            convolutional layers.
        dense_hidden_dims: list of int
            the size of dense hidden dimensions, if None do not include
            dense hidden layers.
        output_channels: int
            the number of output channels.
        output_dim: int or list/tuple of int
            the size of output (one value per spatial dimension).
        noise_out_logvar: float, default -3
            the init output log var.
        noise_fixed: bool, default True
            if True the output log var is kept fixed, otherwise it is
            estimated during training.
        act_func: callable, default None
            the activation function class, instantiated without argument
            after each layer. If None, nn.ReLU is used.
        final_activation: bool, default False
            apply activation function to the final layer.
        dropout: float, default 0
            define the dropout rate.
        """
        super(Decoder, self).__init__()
        self.act_func = act_func or nn.ReLU
        self.output_channels = output_channels
        self.conv_flts = conv_flts

        # torch.Size is a tuple subclass: any list/tuple describes one size
        # per spatial axis, anything else is a single 1d size.
        if isinstance(output_dim, (list, tuple)):
            output_dim = list(output_dim)
        else:
            output_dim = [output_dim]
        self.output_dim = output_dim
        ndim = len(output_dim)

        if conv_flts is not None:
            # Sizes met through the matching encoder convolutions, stored
            # from the smallest one to the output size.
            self.all_dims = Encoder.final_conv_dim(
                output_dim,
                kernels=[5] * (len(conv_flts) - 1) + [3],
                paddings=[2] * len(conv_flts))
            self.final_dim = self.all_dims[-1]
            self.all_dims = self.all_dims[::-1]
            flatten_dim = conv_flts[0] * np.prod(self.final_dim)
        else:
            self.final_dim = [-1]
            flatten_dim = output_channels * np.prod(output_dim)
        # Hand a plain Python int (not a numpy scalar) to nn.Linear.
        flatten_dim = int(flatten_dim)

        # Dense part: its last layer is the network output only when there
        # is no convolutional part, otherwise it is always activated.
        if dense_hidden_dims is None:
            dense_hidden_dims = []
        w_dense_layers = Encoder.init_dense_layers(
            latent_dim, list(dense_hidden_dims) + [flatten_dim],
            self.act_func, dropout,
            final_activation=bool(final_activation or conv_flts is not None))
        self.w_dense = nn.Sequential(*w_dense_layers)

        # Convolutional part
        if conv_flts is not None:
            w_conv_layers = Decoder.init_conv_layers(
                conv_flts[0], list(conv_flts[1:]) + [output_channels],
                self.act_func, dropout, ndim,
                final_activation=final_activation)
            self.w_conv = nn.Sequential(*w_conv_layers)
        else:
            self.w_conv = None

        # Output noise: one log variance per output element, shared by all
        # samples.
        self.w_out_logvar = torch.nn.Parameter(
            data=torch.full(
                (1, output_channels, *output_dim), float(noise_out_logvar),
                dtype=torch.float32),
            requires_grad=(not noise_fixed))

    @staticmethod
    def init_conv_layers(input_channels, flts, act_func, dropout, ndim=1,
                         final_activation=True):
        """ Create the convolutional layers.

        Each transposed convolution uses a stride of 2 and a padding of 2,
        with a kernel of size 5 except for the first one that uses a
        kernel of size 3.

        Parameters
        ----------
        input_channels: int
            the number of input channels.
        flts: list of int
            the number of output channels of each transposed convolution.
        act_func: callable
            the activation function class.
        dropout: float
            the dropout rate, no dropout layer is added if not positive.
        ndim: int, default 1
            the number of spatial dimensions, selects
            'nn.ConvTranspose<ndim>d'.
        final_activation: bool, default True
            if False the last transposed convolution is left bare.

        Returns
        -------
        layers: list of nn.Module
            the layers, ready to be wrapped in a nn.Sequential.
        """
        convt_fn = getattr(nn, "ConvTranspose{0}d".format(ndim))
        layers = []
        current_channels = input_channels
        last = len(flts) - 1
        for cnt, n_flts in enumerate(flts):
            layers.append(convt_fn(
                current_channels, out_channels=n_flts,
                kernel_size=(3 if cnt == 0 else 5), stride=2, padding=2))
            # A bare final layer gets neither activation nor dropout, so
            # that the reconstruction is never randomly zeroed.
            if final_activation or cnt != last:
                layers.append(act_func())
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
            current_channels = n_flts
        return layers

    def forward(self, z):
        """ The forward method.

        Parameters
        ----------
        z: torch.Tensor
            the latent samples, the first axis being the samples.

        Returns
        -------
        p: torch.distributions.Normal
            the distribution over the output space.
        """
        out = self.w_dense(z)
        if self.w_conv is None:
            # Restore the full spatial layout so that the mean matches the
            # shape of the output log variance. Fall back to a flat layout
            # for instances created before 'output_dim' was stored.
            spatial_dim = getattr(self, "output_dim", None) or [-1]
            out = out.view(out.size(0), self.output_channels, *spatial_dim)
        else:
            out = out.view(
                out.size(0), self.conv_flts[0],
                *[int(dim) for dim in self.final_dim])
            idx_layer = 0
            for module in self.w_conv:
                out = module(out)
                # Restore orig tensor size (preserve autoencoder structure)
                if isinstance(module, nn.modules.conv._ConvNd):
                    idx_layer += 1
                    target_dim = self.all_dims[idx_layer]
                    # 'func.pad' expects (before, after) pairs starting
                    # from the last axis.
                    pads = []
                    for axis, dim in enumerate(target_dim[::-1], start=1):
                        delta_dim = int(dim) - out.size(-axis)
                        pads.extend(
                            [delta_dim // 2, delta_dim - delta_dim // 2])
                    out = func.pad(out, pads)
        # scale = sqrt(exp(logvar))
        return Normal(loc=out, scale=self.w_out_logvar.exp().pow(0.5))
@Networks.register
@DeepLearningDecorator(family=("encoder", "vae"))
class VAENet(nn.Module):
    """ The VAE architecture.

    Spatiotemporal Trajectories in Resting-state FMRI Revealed by
    Convolutional Variational Autoencoder, Xiaodi Zhang, Eric Maltbie,
    Shella Keilholz, bioRxiv 2021.

    Deep Variational Autoencoder for Modeling functional brain networks and
    ADHD identification, ISBI 2020.

    Sparse Multi-Channel Variational Autoencoder for the Joint Analysis of
    Heterogeneous Data, Luigi Antelmi, Nicholas Ayache, Philippe Robert,
    Marco Lorenzi, PMLR 2019.
    """

    def __init__(self, input_channels, input_dim, conv_flts,
                 dense_hidden_dims, latent_dim, noise_out_logvar=-3,
                 noise_fixed=True, log_alpha=None, act_func=None,
                 final_activation=False, dropout=0, sparse=False,
                 encoder=None, decoder=None):
        """ Init class.

        Parameters
        ----------
        input_channels: int
            the number of input channels.
        input_dim: int or list of int
            the size of input.
        conv_flts: list of int
            the size of convolutional filters, if None do not include
            convolutional layers.
        dense_hidden_dims: list of int
            the size of dense hidden dimensions, if None do not include
            dense hidden layers.
        latent_dim: int
            the latent dimension.
        noise_out_logvar: float, default -3
            the init output log var, forwarded to the decoder.
        noise_fixed: bool, default True
            forwarded to the decoder, controls whether the output log var
            is kept fixed.
        log_alpha: nn.Parameter, default None
            dropout probabilities estimate, only used if 'sparse' is True.
            If None, it is randomly initialized.
        act_func: callable, default None
            the activation function.
        final_activation: bool, default False
            apply activation function to the final layer.
        dropout: float, default 0
            define the dropout rate.
        sparse: bool, default False
            use sparsity constraint.
        encoder: nn.Module, default None
            a custom encoder class, called with the same arguments as
            'Encoder'.
        decoder: nn.Module, default None
            a custom decoder class, called with the same arguments as
            'Decoder'.
        """
        super(VAENet, self).__init__()
        if isinstance(input_dim, tuple):
            input_dim = list(input_dim)
        self.latent_dim = latent_dim
        self.act_func = act_func

        # Sparsity parameter, ignored unless 'sparse' is requested
        if not sparse:
            self.log_alpha = None
        elif log_alpha is None:
            self.log_alpha = nn.Parameter(
                torch.FloatTensor(1, self.latent_dim).normal_(0, 0.1))
        else:
            self.log_alpha = log_alpha

        encoder = encoder or Encoder
        decoder = decoder or Decoder
        self.encode = encoder(
            input_channels, input_dim, conv_flts, dense_hidden_dims,
            latent_dim, act_func=act_func, dropout=dropout,
            log_alpha=self.log_alpha)

        # The decoder mirrors the encoder: same layers in reverse order
        dec_conv_flts = None
        if conv_flts is not None:
            dec_conv_flts = conv_flts[::-1]
        dec_dense_hidden_dims = None
        if dense_hidden_dims is not None:
            dec_dense_hidden_dims = dense_hidden_dims[::-1]
        self.decode = decoder(
            latent_dim, dec_conv_flts, dec_dense_hidden_dims,
            input_channels, input_dim, noise_out_logvar=noise_out_logvar,
            noise_fixed=noise_fixed, act_func=act_func,
            final_activation=final_activation, dropout=dropout)

        # TODO: Not working well
        # self.kernel_initializer()

    def forward(self, x):
        """ The forward method.

        Returns
        -------
        p: distribution
            the distribution returned by the decoder.
        extra: dict
            the encoder distribution ('q'), the latent sample ('z') and
            the network itself ('model').
        """
        q = self.encode(x)
        z = self.reparameterize(q)
        p = self.decode(z)
        return p, {"q": q, "z": z, "model": self}

    def reparameterize(self, q):
        """ Implement the reparametrization trick.

        A differentiable sample is drawn in training mode, otherwise the
        location of the distribution is returned.
        """
        if self.training:
            return q.rsample()
        return q.loc

    @staticmethod
    def p_to_prediction(p):
        """ Get the prediction from various types of distributions.

        Parameters
        ----------
        p: Normal or Bernoulli
            the distribution.

        Returns
        -------
        pred: numpy.ndarray
            the location of a Normal distribution or the probabilities of
            a Bernoulli distribution.

        Raises
        ------
        NotImplementedError
            if the type of distribution is not supported.
        """
        if isinstance(p, Normal):
            values = p.loc
        elif isinstance(p, Bernoulli):
            values = p.probs
        else:
            raise NotImplementedError
        return values.cpu().detach().numpy()

    def reconstruct(self, x, sample=False, dropout_threshold=None):
        """ Reconstruct a new data from a given input with or without
        resampling.

        Parameters
        ----------
        x: torch.Tensor
            the input data.
        sample: bool, default False
            if True sample the latent distribution, otherwise use its
            location.
        dropout_threshold: float, default None
            if set, threshold the latent code with 'apply_threshold'.

        Returns
        -------
        pred: numpy.ndarray
            the reconstruction.
        """
        with torch.no_grad():
            posterior = self.encode(x)
            z = posterior.sample() if sample else posterior.loc
            if dropout_threshold is not None:
                z = self.apply_threshold(z, dropout_threshold)
            p = self.decode(z)
        return self.p_to_prediction(p)

    def generate(self, z=None, device=None):
        """ Generate a new data from a given sample or a random one.

        Parameters
        ----------
        z: torch.Tensor, default None
            the latent code, if None a single code is drawn from a
            standard Normal distribution.
        device: torch.device, default None
            the device the latent code is moved to, CPU if None.

        Returns
        -------
        pred: numpy.ndarray
            the generated data.
        """
        device = device or torch.device("cpu")
        with torch.no_grad():
            if z is None:
                z = Normal(
                    loc=torch.zeros(1, self.latent_dim), scale=1).sample()
            z = z.to(device)
            p = self.decode(z)
        return VAENet.p_to_prediction(p)

    def apply_threshold(self, z, threshold, keep_dims=True, reorder=False):
        """ Threshold the latent samples based on the estimated dropout
        probabilities.

        Parameters
        ----------
        z: torch.Tensor (n_samples, latent_dim)
            the latent samples, left untouched.
        threshold: float
            latent dimensions with a dropout probability greater than or
            equal to this value are discarded, must not exceed 1.
        keep_dims: bool, default True
            if True the discarded dimensions are set to zero, otherwise
            they are removed.
        reorder: bool, default False
            if True sort the returned dimensions by increasing dropout
            probability.

        Returns
        -------
        z: torch.Tensor
            the thresholded latent samples.
        """
        assert(threshold <= 1.0)
        # Work on a flat vector of probabilities so that it can be masked
        # whatever the shape of 'log_alpha'.
        dropout = self.dropout.reshape(-1)
        keep = dropout < threshold
        if keep_dims:
            # Do not modify the caller tensor in place.
            z = z.clone()
            z[:, ~keep] = 0
        else:
            z = z[:, keep]
            dropout = dropout[keep]
        if reorder:
            z = z[:, torch.argsort(dropout)]
        return z

    @property
    def dropout(self):
        """ Compute the dropout probabilities 'alpha / (alpha + 1)'.

        Raises
        ------
        NotImplementedError
            if the network has no 'log_alpha' parameter.
        """
        if self.log_alpha is None:
            raise NotImplementedError
        alpha = torch.exp(self.log_alpha.detach())
        return alpha / (alpha + 1)

    def kernel_initializer(self):
        """ Init network weights by calling 'init_weight' on each module.
        """
        for module in self.modules():
            init_weight(module, self.act_func)

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.