Menu

Helper Module for Deep Learning.

Source code for pynet.models.attention

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
The Spatiotemporal Attention Autoencoder network (STAAENet).
"""

# Imports
import logging
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks
import torch
import torch.nn as nn
import torch.nn.functional as func

# Global parameters
logger = logging.getLogger("pynet")


@Networks.register
@DeepLearningDecorator(family=("encoder", ))
class STAAENet(nn.Module):
    """ SpatioTemporal Attention AutoEncoder (STAAE).

    The encoder stacks two Linear + Tanh layers (input_dim -> 512 -> 128)
    followed by a self-attention layer that outputs 64 features. The decoder
    mirrors it: a self-attention layer (64 -> 128) followed by two
    Linear + Tanh layers (128 -> 512 -> input_dim).
    """

    def __init__(self, input_dim, nodecoding=False):
        """ Init class.

        Parameters
        ----------
        input_dim: int
            the input dimension, i.e. the size of the last axis of the
            tensors given to the network.
        nodecoding: bool, default False
            if set do not apply the decoding: the forward pass then returns
            the latent code.
        """
        super(STAAENet, self).__init__()
        self.input_dim = input_dim
        self.nodecoding = nodecoding

        # Build Encoder
        # NOTE: each layer is kept as an attribute AND referenced in the
        # 'encoder' container, so its parameters are reachable under both
        # names in the state dict.
        self.enc_dense1 = nn.Sequential(
            nn.Linear(self.input_dim, 512),
            nn.Tanh())
        self.enc_dense2 = nn.Sequential(
            nn.Linear(512, 128),
            nn.Tanh())
        self.enc_attention = SelfAttention(128, 64)
        self.encoder = nn.Sequential(
            self.enc_dense1,
            self.enc_dense2,
            self.enc_attention)

        # Build Decoder (same aliasing scheme as the encoder)
        self.dec_attention = SelfAttention(64, 128)
        self.dec_dense1 = nn.Sequential(
            nn.Linear(128, 512),
            nn.Tanh())
        self.dec_dense2 = nn.Sequential(
            nn.Linear(512, self.input_dim),
            nn.Tanh())
        self.decoder = nn.Sequential(
            self.dec_attention,
            self.dec_dense1,
            self.dec_dense2)

    def encode(self, x):
        """ Encodes the input by passing it through the encoder network and
        returns the latent code.

        Parameters
        ----------
        x: Tensor, (N, C, F)
            input tensor to encode, where F is equal to 'input_dim'.

        Returns
        -------
        code: Tensor, (N, C, 64)
            the latent code produced by the encoder attention layer.
        """
        logger.debug("Encode...")
        self.debug("input", x)
        code = self.encoder(x)
        return code

    def decode(self, x):
        """ Maps the given latent code back onto the input space.

        Parameters
        ----------
        x: Tensor, (N, C, 64)
            the latent code, as returned by the 'encode' method.

        Returns
        -------
        x: Tensor, (N, C, F)
            the prediction, where F is equal to 'input_dim'. Values lie in
            [-1, 1] since the last layer is a Tanh.
        """
        logger.debug("Decode...")
        self.debug("x", x)
        decoded = self.decoder(x)
        self.debug("decoded", decoded)
        return decoded

    def forward(self, x, **kwargs):
        """ Forward pass.

        Parameters
        ----------
        x: Tensor, (N, C, F)
            input tensor, where F is equal to 'input_dim'.
        kwargs: dict
            extra named arguments, accepted but not used.

        Returns
        -------
        out: Tensor
            the latent code if 'nodecoding' is set, the reconstruction of
            the input otherwise.
        """
        logger.debug("STAAE Net...")
        code = self.encode(x)
        if self.nodecoding:
            return code
        return self.decode(code)

    @staticmethod
    def debug(name, tensor):
        """ Print debug message.

        Parameters
        ----------
        name: str
            the tensor name in the displayed message.
        tensor: Tensor
            a pytorch tensor.
        """
        # Skip the device query and the string formatting when the debug
        # level is disabled: this method is called on every forward pass.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(" {3}: {0} - {1} - {2}".format(
                tensor.shape, tensor.get_device(), tensor.dtype, name))
class SelfAttention(nn.Module):
    """ Single-head scaled dot-product self-attention layer.

    The query, key and value projections are stored in one learnable kernel
    of shape (3, input_dim, output_dim) and are applied without bias.
    """

    def __init__(self, input_dim, output_dim):
        """ Init class.

        Parameters
        ----------
        input_dim: int
            the size of the last axis of the input tensor.
        output_dim: int
            the size of the last axis of the output tensor.
        """
        super(SelfAttention, self).__init__()
        self.output_dim = output_dim
        self.kernel = nn.Parameter(
            torch.zeros(3, input_dim, output_dim), requires_grad=True)
        # Overwrite the zeros with samples drawn from U(0, 1).
        nn.init.uniform_(self.kernel)

    def forward(self, x, **kwargs):
        """ Forward pass.

        Parameters
        ----------
        x: Tensor, (N, C, F)
            input tensor, where F is equal to 'input_dim'. Extra or missing
            leading axes are supported as long as the last two axes are
            (C, F).
        kwargs: dict
            extra named arguments, accepted but not used.

        Returns
        -------
        V: Tensor, (N, C, output_dim)
            the attended values.
        """
        logger.debug("Self Attention...")
        self.debug("x", x)
        self.debug("kernel", self.kernel)
        queries = torch.matmul(x, self.kernel[0])
        self.debug("WQ", queries)
        keys = torch.matmul(x, self.kernel[1])
        self.debug("WK", keys)
        values = torch.matmul(x, self.kernel[2])
        self.debug("WV", values)
        # Pairwise similarities between the C positions: (..., C, C).
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        scores = scores / (self.output_dim ** 0.5)
        self.debug("QK", scores)
        # Normalize over the key axis so that the weights of each query sum
        # to one. Normalizing over axis 0 would mix the samples of a batch
        # and would yield all-ones weights for a batch of size one.
        weights = torch.softmax(scores, dim=-1)
        self.debug("QK", weights)
        attended = torch.matmul(weights, values)
        self.debug("V", attended)
        return attended

    def compute_output_shape(self, input_shape):
        """ Get the shape of the tensor returned by the forward pass.

        Parameters
        ----------
        input_shape: tuple of int
            the shape of the input tensor.

        Returns
        -------
        output_shape: tuple of int
            the input shape with its last item replaced by 'output_dim'.
        """
        return tuple(input_shape[:-1]) + (self.output_dim, )

    @staticmethod
    def debug(name, tensor):
        """ Print debug message.

        Parameters
        ----------
        name: str
            the tensor name in the displayed message.
        tensor: Tensor
            a pytorch tensor.
        """
        # Skip the device query and the string formatting when the debug
        # level is disabled: this method is called many times per forward
        # pass.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(" {3}: {0} - {1} - {2}".format(
                tensor.shape, tensor.get_device(), tensor.dtype, name))

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.