Menu

Helper Module for Deep Learning.

Source code for pynet.models.vae.pmvae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Pathway Modules Variational Auto-Encoder (pmVAE).

[1] pmVAE: Learning Interpretable Single-Cell Representations with Pathway
Modules, Gilles Gut, biorxiv 2021.

Code: https://github.com/ratschlab/pmvae
"""

# Imports
import logging
import math
import numpy as np
import pandas as pd
from scipy.linalg import block_diag
import torch
import torch.nn as nn
import torch.nn.functional as func
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks, init_weight


@Networks.register
@DeepLearningDecorator(family=("encoder", "vae", "genetic"))
class PMVAE(nn.Module):
    """ Pathway Modules Variational Auto-Encoder (pmVAE).

    The network is a set of small encoder/decoder modules, one per pathway,
    implemented as masked linear layers so that modules do not share
    connections. A final 'merger' layer maps the module outputs back to
    the genes.
    """
    def __init__(self, membership_mask, latent_dim, hidden_layers,
                 bias_last_layer=False, add_auxiliary_module=True,
                 terms=None, activation=None):
        """ pmVAE constructs a pathway-factorized latent space.

        Parameters
        ----------
        membership_mask: bool array or pandas.DataFrame (pathways, genes)
            a binary mask encoding which genes belong to which pathways.
            When a DataFrame is given, its index is used as 'terms' (any
            value passed in 'terms' is then ignored).
        latent_dim: int
            the dimension of each module latent space.
        hidden_layers: list of int
            the dimension of each module encoder/decoder hidden layer.
        bias_last_layer: bool, default False
            use a bias term on the final decoder output.
        add_auxiliary_module: bool, default True
            include a fully connected pathway module.
        terms: list of str (pathways, ), default None
            the pathway names. If None (and the mask is not a DataFrame),
            names must be given later to 'latent_space_names'.
        activation: klass, default None
            the activation function class, instantiated without argument.
            If None, nn.ELU is used.

        Raises
        ------
        ValueError
            if the number of terms does not match the number of modules.
        """
        super(PMVAE, self).__init__()
        self.n_annotated_modules, self.num_feats = membership_mask.shape
        if isinstance(membership_mask, pd.DataFrame):
            terms = membership_mask.index
            membership_mask = membership_mask.values
        self.add_auxiliary_module = add_auxiliary_module
        if add_auxiliary_module:
            # The auxiliary module is an extra pathway connected to all genes.
            membership_mask = np.vstack(
                (membership_mask, np.ones_like(membership_mask[0])))
            if terms is not None:
                terms = list(terms) + ["AUXILIARY"]
        self.activation = activation or nn.ELU

        # The encoder maps the input data to the latent space.
        self.encoder = PMVAE.build_encoder(
            membership_mask, hidden_layers, latent_dim, self.activation,
            batch_norm=True)

        # The decoder maps a code to the output of each module.
        # The merger connects each module output to its genes.
        self.decoder, self.merger = PMVAE.build_decoder(
            membership_mask, hidden_layers, latent_dim, self.activation,
            batch_norm=True, bias_last_layer=bias_last_layer)

        self.membership_mask = membership_mask
        # NOTE(review): the decoder output has hidden_layers[0] units per
        # module while this mask uses hidden_layers[-1]; both agree only
        # when the first and last hidden sizes are equal - confirm.
        self.module_isolation_mask = PMVAE.build_module_isolation_mask(
            self.membership_mask.shape[0], hidden_layers[-1])

        self._latent_dim = latent_dim
        self._hidden_layers = hidden_layers

        # 'terms' is optional: keep None so that 'latent_space_names' can
        # fall back on the names given at call time.
        if terms is None:
            self.terms = None
        else:
            terms = list(terms)
            if len(terms) != len(self.membership_mask):
                raise ValueError(
                    "Expected {0} terms (one per module), got {1}.".format(
                        len(self.membership_mask), len(terms)))
            self.terms = terms

        self.kernel_initializer()

    def kernel_initializer(self):
        """ Init network weights.

        Each MaskedLinear weight is drawn from U(-limit, limit) with
        limit = sqrt(6 / in_features), and each bias is set to zero.
        """
        for module in self.modules():
            if not isinstance(module, MaskedLinear):
                continue
            # For a linear layer the fan-in is the number of input features.
            limit = math.sqrt(6 / module.in_features)
            nn.init.uniform_(module.weight, a=-limit, b=limit)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    @staticmethod
    def build_base_masks(membership_mask, hidden_layers, latent_dim):
        """ Builds the masks used by the encoders/decoders.

        Parameters
        ----------
        membership_mask: bool array (pathways, genes)
            a binary mask encoding which genes belong to which pathways.
        hidden_layers: list of int
            the dimension of each module encoder/decoder hidden layer.
        latent_dim: int
            the dimension of each module latent space.

        Returns
        -------
        base: list of array
            pathway mask assigns genes to pathway modules, and separation
            masks keep modules separated. Encoder modifies the last
            separation mask to give mu/logvar, and the decoder reverses
            the order of the masks. All masks are float32.
        """
        n_modules, n_feats = membership_mask.shape
        base = [PMVAE.build_pathway_mask(
            n_feats, membership_mask, hidden_layers[0])]
        # Accept any sequence (list, tuple...) of hidden layer sizes.
        dims = list(hidden_layers) + [latent_dim]
        for input_dim, output_dim in zip(dims[:-1], dims[1:]):
            base.append(PMVAE.build_separation_mask(
                input_dim, output_dim, n_modules))
        return [mask.astype(np.float32) for mask in base]

    @staticmethod
    def build_pathway_mask(nfeats, membership_mask, hidden_layers):
        """ Connects genes to pathway modules.

        Repeats the membership mask for each module input node.
        See M in Methods 2.2.

        Parameters
        ----------
        nfeats: int
            the number of genes (not used by the computation).
        membership_mask: bool array (pathways, genes)
            a binary mask encoding which genes belong to which pathways.
        hidden_layers: int
            the number of times each pathway row is repeated, i.e. the
            number of input nodes of each module.

        Returns
        -------
        mask: array (genes, pathways * hidden_layers)
            the genes to module nodes connections.
        """
        return np.repeat(membership_mask, hidden_layers, axis=0).T

    @staticmethod
    def build_separation_mask(input_dim, out_put_dim, nmodules):
        """ Removes connections betweens pathway modules.

        Block diagonal matrix, see Sigma in Methods 2.2.

        Parameters
        ----------
        input_dim: int
            the number of input nodes of each module.
        out_put_dim: int
            the number of output nodes of each module.
        nmodules: int
            the number of modules.

        Returns
        -------
        mask: array (nmodules * input_dim, nmodules * out_put_dim)
            ones inside each module block, zeros elsewhere.
        """
        blocks = [np.ones((input_dim, out_put_dim))] * nmodules
        return block_diag(*blocks)

    @staticmethod
    def build_module_isolation_mask(nmodules, module_output_dim):
        """ Isolates a single module for gradient steps.

        Used for the local reconstruction terms, drops all modules except
        one.

        Parameters
        ----------
        nmodules: int
            the number of modules.
        module_output_dim: int
            the number of output nodes of each module.

        Returns
        -------
        mask: array (nmodules, nmodules * module_output_dim)
            row i selects the output nodes of module i.
        """
        blocks = [np.ones((1, module_output_dim))] * nmodules
        return block_diag(*blocks)

    @staticmethod
    def build_encoder(membership_mask, hidden_layers, latent_dim,
                      activation, batch_norm=True):
        """ Build the encoder module.

        Parameters
        ----------
        membership_mask: bool array (pathways, genes)
            a binary mask encoding which genes belong to which pathways.
        hidden_layers: list of int
            the dimension of each module encoder hidden layer.
        latent_dim: int
            the dimension of each module latent space.
        activation: klass
            the activation function class.
        batch_norm: bool, default True
            add a batch normalization after each masked linear layer.

        Returns
        -------
        encoder: nn.Sequential
            the encoder, whose output concatenates mu and logvar.
        """
        masks = PMVAE.build_base_masks(
            membership_mask, hidden_layers, latent_dim)
        # Duplicate the last mask: one half yields mu, the other logvar.
        masks[-1] = np.hstack((masks[-1], masks[-1]))
        # nn.Linear weights are (out_features, in_features).
        masks = [torch.from_numpy(mask.T) for mask in masks]
        modules = []
        in_features = membership_mask.shape[1]
        last = len(masks) - 1
        for cnt, mask in enumerate(masks):
            out_features = mask.shape[0]
            modules.append(MaskedLinear(in_features, out_features, mask))
            if batch_norm:
                # NOTE(review): eps/momentum look like the Keras
                # BatchNormalization defaults. In PyTorch 'momentum' is the
                # weight given to the new batch statistics, so 0.99 almost
                # discards the running average - confirm whether 0.01 was
                # intended.
                modules.append(nn.BatchNorm1d(out_features, eps=0.001,
                                              momentum=0.99))
            # No activation on the mu/logvar output layer.
            if cnt != last:
                modules.append(activation())
            in_features = out_features
        return nn.Sequential(*modules)

    @staticmethod
    def build_decoder(membership_mask, hidden_layers, latent_dim,
                      activation, batch_norm=True, bias_last_layer=False):
        """ Build the decoder/merger modules.

        Parameters
        ----------
        membership_mask: bool array (pathways, genes)
            a binary mask encoding which genes belong to which pathways.
        hidden_layers: list of int
            the dimension of each module decoder hidden layer.
        latent_dim: int
            the dimension of each module latent space.
        activation: klass
            the activation function class.
        batch_norm: bool, default True
            add a batch normalization after each masked linear layer of
            the decoder.
        bias_last_layer: bool, default False
            use a bias term in the merger.

        Returns
        -------
        decoder: nn.Sequential
            maps a latent code to the module outputs.
        merger: MaskedLinear
            maps the module outputs to the genes.
        """
        masks = PMVAE.build_base_masks(
            membership_mask, hidden_layers, latent_dim)
        in_features = masks[-1].shape[1]
        # The decoder mirrors the encoder: masks are used in reverse order.
        masks = [torch.from_numpy(mask) for mask in masks[::-1]]
        modules = []
        for mask in masks[:-1]:
            out_features = mask.shape[0]
            modules.append(MaskedLinear(in_features, out_features, mask))
            if batch_norm:
                # Same settings (and same review note) as in build_encoder.
                modules.append(nn.BatchNorm1d(out_features, eps=0.001,
                                              momentum=0.99))
            modules.append(activation())
            in_features = out_features
        decoder = nn.Sequential(*modules)
        # The remaining mask is the pathway mask: module nodes to genes.
        merger = MaskedLinear(in_features, masks[-1].shape[0], masks[-1],
                              bias=bias_last_layer)
        return decoder, merger

    def encode(self, x):
        """ Computes the parameters of the inference distribution q(z | x).

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        mu: torch.Tensor (batch_size, n_modules * latent_dim)
            the first half of the encoder output.
        logvar: torch.Tensor (batch_size, n_modules * latent_dim)
            the second half of the encoder output.
        """
        params = self.encoder(x)
        mu, logvar = torch.split(
            params, split_size_or_sections=(params.size(dim=1) // 2), dim=1)
        return mu, logvar

    def decode(self, z):
        """ Computes the reconstruction associated to a latent code.

        Parameters
        ----------
        z: torch.Tensor (batch_size, n_modules * latent_dim)
            the stochastic latent state z.

        Returns
        -------
        global_recon: torch.Tensor (batch_size, data_size)
            the merged output of all the modules.
        """
        module_outputs = self.decoder(z)
        return self.merger(module_outputs)

    def reparametrize(self, mu, logvar):
        """ Implement the reparametrization trick.

        Returns mu + exp(logvar / 2) * eps, with eps drawn from a standard
        normal distribution.
        """
        eps = torch.randn_like(logvar)
        return mu + torch.exp(logvar / 2.) * eps

    def forward(self, x):
        """ The forward method.

        Parameters
        ----------
        x: torch.Tensor (batch_size, data_size)
            the input data.

        Returns
        -------
        global_recon: torch.Tensor
            the merger output.
        extra: dict
            the sampled code 'z', the decoder 'module_outputs', the 'mu'
            and 'logvar' parameters, and the 'model' itself.
        """
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        module_outputs = self.decoder(z)
        global_recon = self.merger(module_outputs)
        return global_recon, {"z": z, "module_outputs": module_outputs,
                              "mu": mu, "logvar": logvar, "model": self}

    def get_masks_for_local_losses(self):
        """ Get module/pathway associated masks.

        Returns
        -------
        masks: iterator of 2-uplet
            pairs (membership mask row, module isolation mask row), one
            per annotated module: the auxiliary module, when present, is
            skipped.
        """
        if self.add_auxiliary_module:
            return zip(self.membership_mask[:-1],
                       self.module_isolation_mask[:-1])
        return zip(self.membership_mask, self.module_isolation_mask)

    def latent_space_names(self, terms=None):
        """ Get latent space associated names.

        Parameters
        ----------
        terms: list of str, default None
            the pathway names, used only when no names were set at
            construction time.

        Returns
        -------
        latent_dim_names: list of str
            names of the form '<term>-<index>', with index in
            [0, latent_dim), grouped by term.

        Raises
        ------
        ValueError
            if no pathway names are available.
        """
        terms = self.terms or terms
        if terms is None:
            raise ValueError("Need to specify gene set terms.")
        terms = list(terms)
        if (self.add_auxiliary_module and
                len(terms) == self.n_annotated_modules):
            terms.append("AUXILIARY")
        return ["-".join((str(term), str(index)))
                for term in terms for index in range(self._latent_dim)]
class MaskedLinear(nn.Linear):
    """ Linear layer whose weights are multiplied element-wise by a fixed
    mask each time the layer is applied.
    """
    def __init__(self, in_features, out_features, mask, *args, **kwargs):
        """ Init class.

        Parameters
        ----------
        in_features: int
            size of each input sample.
        out_features: int
            size of each output sample.
        mask: torch.Tensor
            the tensor the weights are multiplied with: it must have the
            same shape as the layer weights. It is stored as a parameter
            that does not require gradient.
        *args, **kwargs
            extra arguments forwarded to nn.Linear.
        """
        super().__init__(in_features, out_features, *args, **kwargs)
        self.mask = nn.Parameter(mask, requires_grad=False)

    def forward(self, inputs):
        """ Forward method: apply the linear map with the masked weights.
        """
        assert self.mask.shape == self.weight.shape
        masked_weight = self.weight * self.mask
        return func.linear(inputs, masked_weight, self.bias)

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.