Helper Module for Deep Learning.
Source code for pynet.models.nvnet
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
NvNet: combination of Vnet and VAE (variational auto-encoder).
"""
# Imports
import logging
import torch
import torch.nn as nn
import torch.nn.functional as func
import numpy as np
from torch.distributions import Normal
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks
# Global parameters
# Logger named "pynet"; the classes below only use it for debug-level traces
# of intermediate tensor shapes.
logger = logging.getLogger("pynet")
@Networks.register
@DeepLearningDecorator(family="segmenter")
class NvNet(nn.Module):
    """ NvNet: combination of Vnet and VAE (variational auto-encoder).

    The variational auto-encoder branch reconstructs the input image jointly
    with the segmentation in order to regularize the shared encoder.

    Reference: https://arxiv.org/pdf/1810.11654.pdf.
    Code: https://github.com/athon2/BraTS2018_NvNet.
    """
    def __init__(self, input_shape, in_channels, num_classes,
                 activation="relu", normalization="group_normalization",
                 mode="trilinear", with_vae=True):
        """ Init class.

        Parameters
        ----------
        input_shape: tuple
            the tensor data shape (X, Y, Z).
        in_channels: int
            number of channels in the input tensor.
        num_classes: int
            the number of features in the output segmentation map.
        activation: str, default 'relu'
            the activation function.
        normalization: str, default 'group_normalization'
            the normalization function.
        mode: str, default 'trilinear'
            the interpolation mode.
        with_vae: bool, default True
            enable/disable vae penalty.

        Raises
        ------
        ValueError
            if 'activation', 'normalization' or 'mode' is not supported.
        """
        # Inheritance
        nn.Module.__init__(self)

        # Check inputs
        if activation not in ("relu", "elu"):
            raise ValueError(
                "'{}' is not a valid activation. Only 'relu' "
                "and 'elu' are allowed.".format(activation))
        # The trailing comma matters: without it the right-hand side is a
        # plain string and 'in' becomes a substring test that would accept
        # values such as 'group'.
        if normalization not in ("group_normalization", ):
            raise ValueError(
                "'{}' is not a valid normalization. Only "
                "'group_normalization' is allowed.".format(normalization))
        # NOTE(review): per the torch documentation only 'nearest',
        # 'trilinear' and 'area' apply to 5D (volumetric) tensors; the other
        # accepted names are expected to fail at interpolation time - confirm
        # before tightening this check.
        if mode not in ("nearest", "linear", "bilinear", "bicubic",
                        "trilinear", "area"):
            raise ValueError(
                "'{}' is not a valid interpolation mode: see "
                "'torch.nn.functional.interpolate'for a list of allowed "
                "modes.".format(mode))

        # Declare class parameters
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.activation = activation
        self.normalization = normalization
        self.mode = mode
        self.with_vae = with_vae

        # Encoder Blocks: encoder parts use ResNet blocks (two 3x3x3 conv,
        # group normalization (better than batch norm for small batch size),
        # and ReLU), followed by an additive identity skip connection.
        # A downsizing of 2 is performed with strided convolutions.
        # The number of features doubles at each level.
        self.in_conv0 = DownSampling(
            in_channels=self.in_channels, out_channels=32, stride=1,
            kernel_size=3, dropout_rate=0.2, bias=True)
        self.en_block0 = EncoderBlock(
            in_channels=32, out_channels=32, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_down1 = DownSampling(
            in_channels=32, out_channels=64, stride=2, kernel_size=3)
        self.en_block1_0 = EncoderBlock(
            in_channels=64, out_channels=64, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_block1_1 = EncoderBlock(
            in_channels=64, out_channels=64, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_down2 = DownSampling(
            in_channels=64, out_channels=128, stride=2, kernel_size=3)
        self.en_block2_0 = EncoderBlock(
            in_channels=128, out_channels=128, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_block2_1 = EncoderBlock(
            in_channels=128, out_channels=128, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_down3 = DownSampling(
            in_channels=128, out_channels=256, stride=2, kernel_size=3)
        self.en_block3_0 = EncoderBlock(
            in_channels=256, out_channels=256, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_block3_1 = EncoderBlock(
            in_channels=256, out_channels=256, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_block3_2 = EncoderBlock(
            in_channels=256, out_channels=256, kernel_size=3,
            activation=activation, normalization=normalization)
        self.en_block3_3 = EncoderBlock(
            in_channels=256, out_channels=256, kernel_size=3,
            activation=activation, normalization=normalization)

        # Decoder Blocks: similar to the encoder but with a single block
        # per level.
        # Upsizing halves the number of features (1x1x1 conv) and resizes
        # the features to the spatial shape of the matching encoder level
        # (interpolation with the requested mode).
        # The skip connection of the same encoder level is merged in.
        # A final 1x1x1 conv maps the features to 'num_classes' channels.
        self.de_up2 = LinearUpSampling(
            in_channels=256, out_channels=128, mode=mode)
        self.de_block2 = DecoderBlock(
            in_channels=128, out_channels=128, kernel_size=3,
            activation=activation, normalization=normalization)
        self.de_up1 = LinearUpSampling(
            in_channels=128, out_channels=64, mode=self.mode)
        self.de_block1 = DecoderBlock(
            in_channels=64, out_channels=64, kernel_size=3,
            activation=activation, normalization=normalization)
        self.de_up0 = LinearUpSampling(
            in_channels=64, out_channels=32, mode=mode)
        self.de_block0 = DecoderBlock(
            in_channels=32, out_channels=32, kernel_size=3,
            activation=activation, normalization=normalization)
        self.de_end = OutputTransition(
            in_channels=32, out_channels=num_classes)

        # Variational Auto-Encoder: reduce the input to a low dimensional
        # space of 256 (128 to represent the mean and 128 to represent the
        # std). A sample is drawn from the Gaussian distribution and
        # reconstructed into the input image shape following the decoder
        # architecture without inter-level skip connections.
        if self.with_vae:
            # Four halvings: the three strided encoder convolutions plus the
            # strided convolution of the VAE resampling block.
            self.shapes = self._downsample_shape(
                self.input_shape,
                nb_iterations=4,
                scale_factor=2)
            logger.debug("Shapes: {0}".format(self.shapes))
            self.vae = VAE(
                shapes=self.shapes,
                in_channels=256,
                out_channels=self.in_channels,
                kernel_size=3,
                activation=activation,
                normalization=normalization,
                mode=mode)

    def _downsample_shape(self, shape, nb_iterations=1, scale_factor=2):
        """ Compute the successive shapes obtained by repeated downsampling.

        Parameters
        ----------
        shape: sequence of int
            the starting shape.
        nb_iterations: int, default 1
            the number of downsampling steps.
        scale_factor: int, default 2
            the divisor applied at each step (the result is rounded up).

        Returns
        -------
        all_shapes: list of list of int
            the starting shape followed by the 'nb_iterations' downsampled
            shapes.
        """
        shape = np.asarray(shape)
        all_shapes = [shape.astype(int).tolist()]
        for idx in range(nb_iterations):
            shape = np.ceil(shape / scale_factor)
            all_shapes.append(shape.astype(int).tolist())
        return all_shapes

    def forward(self, x):
        """ Run the segmentation network, and the VAE branch when enabled.

        Parameters
        ----------
        x: torch.Tensor
            the input tensor.

        Returns
        -------
        out: torch.Tensor or 2-uplet
            without VAE, the output of the final convolution. With VAE, a
            tuple made of the final convolution output concatenated with the
            VAE output along dimension 1, and the dictionary returned by the
            VAE branch.
        """
        logger.debug("NVnet...")
        logger.debug("Tensor: {0}".format(x.shape))

        # Encoder path
        out_init = self.in_conv0(x)
        logger.debug("Initial conv: {0}".format(out_init.shape))
        out_en0 = self.en_block0(out_init)
        out_en1 = self.en_block1_1(self.en_block1_0(self.en_down1(out_en0)))
        logger.debug("Encoding block 1: {0}".format(out_en1.shape))
        out_en2 = self.en_block2_1(self.en_block2_0(self.en_down2(out_en1)))
        logger.debug("Encoding block 2: {0}".format(out_en2.shape))
        out_en3 = self.en_block3_3(self.en_block3_2(self.en_block3_1(
            self.en_block3_0(self.en_down3(out_en2)))))
        logger.debug("Encoding block 3: {0}".format(out_en3.shape))

        # Decoder path: each level receives the encoder output of the same
        # level as skip connection.
        out_de2 = self.de_block2(self.de_up2(out_en3, out_en2))
        logger.debug("Decoding block 1: {0}".format(out_de2.shape))
        out_de1 = self.de_block1(self.de_up1(out_de2, out_en1))
        logger.debug("Decoding block 2: {0}".format(out_de1.shape))
        out_de0 = self.de_block0(self.de_up0(out_de1, out_en0))
        logger.debug("Decoding block 3: {0}".format(out_de0.shape))
        out_end = self.de_end(out_de0)
        logger.debug("Final conv: {0}".format(out_end.shape))

        # Optional VAE branch fed with the deepest encoder features
        if self.with_vae:
            out_vae, vae_kwargs = self.vae(out_en3)
            logger.debug("VAE: {0}".format(out_vae.shape))
            out_final = torch.cat((out_end, out_vae), 1)
            return out_final, vae_kwargs
        else:
            return out_end
class DownSampling(nn.Module):
    """ A (possibly strided) 3D convolution with padding, optionally
    followed by a channel-wise dropout.
    """
    def __init__(self, in_channels, out_channels, stride=2, kernel_size=3,
                 padding=1, dropout_rate=None, bias=False):
        """ Init class.

        Parameters
        ----------
        in_channels: int
            number of channels of the input tensor.
        out_channels: int
            number of channels produced by the convolution.
        stride: int, default 2
            the convolution stride.
        kernel_size: int, default 3
            the convolution kernel size.
        padding: int, default 1
            the convolution padding.
        dropout_rate: float, default None
            if set, the probability used by the in-place 3D dropout applied
            after the convolution.
        bias: bool, default False
            whether the convolution has a bias term.
        """
        super(DownSampling, self).__init__()
        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, bias=bias)
        # The dropout layer only exists when a rate is requested; the flag
        # tells 'forward' whether it can be used.
        self.dropout_flag = dropout_rate is not None
        if self.dropout_flag:
            self.dropout = nn.Dropout3d(p=dropout_rate, inplace=True)

    def forward(self, x):
        """ Apply the convolution, then the dropout when it is enabled.
        """
        features = self.conv1(x)
        return self.dropout(features) if self.dropout_flag else features
class EncoderBlock(nn.Module):
    """ Encoder block: two (normalization, activation, convolution) stages
    followed by the addition of the block input (identity skip connection).

    The addition requires the output to have the same shape as the input,
    which is the case when 'in_channels' equals 'out_channels' with the
    default stride, kernel size and padding.
    """
    def __init__(self, in_channels, out_channels, stride=1, kernel_size=3,
                 padding=1, num_groups=8, activation="relu",
                 normalization="group_normalization"):
        """ Init class.

        Parameters
        ----------
        in_channels: int
            number of channels of the input tensor.
        out_channels: int
            number of channels produced by the convolutions.
        stride: int, default 1
            the stride of both convolutions.
        kernel_size: int, default 3
            the kernel size of both convolutions.
        padding: int, default 1
            the padding of both convolutions.
        num_groups: int, default 8
            the number of groups of the group normalizations.
        activation: str, default 'relu'
            the activation function: 'relu' or 'elu'.
        normalization: str, default 'group_normalization'
            the normalization function.

        Raises
        ------
        ValueError
            if 'activation' or 'normalization' is not supported.
        """
        super(EncoderBlock, self).__init__()
        # Fail at construction time: an unsupported name would otherwise
        # leave the layers undefined and only break later in 'forward'.
        if normalization != "group_normalization":
            raise ValueError(
                "'{}' is not a valid normalization. Only "
                "'group_normalization' is allowed.".format(normalization))
        if activation not in ("relu", "elu"):
            raise ValueError(
                "'{}' is not a valid activation. Only 'relu' "
                "and 'elu' are allowed.".format(activation))
        self.norm1 = nn.GroupNorm(
            num_groups=num_groups,
            num_channels=in_channels)
        # The second stage consumes the output of 'conv1', hence it is sized
        # with 'out_channels'.
        self.norm2 = nn.GroupNorm(
            num_groups=num_groups,
            num_channels=out_channels)
        actv_class = nn.ReLU if activation == "relu" else nn.ELU
        self.actv1 = actv_class(inplace=True)
        self.actv2 = actv_class(inplace=True)
        self.conv1 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)
        self.conv2 = nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)

    def forward(self, x):
        """ Apply the two stages and add the input to the result.
        """
        residual = x
        out = self.norm1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.actv2(out)
        out = self.conv2(out)
        # In-place addition on the convolution output; 'x' is left untouched.
        out += residual
        return out

    def debug(self, name, tensor):
        """ Log the shape, device (as given by 'get_device') and dtype of a
        tensor at debug level, prefixed by 'name'.
        """
        logger.debug(" {3}: {0} - {1} - {2}".format(
            tensor.shape, tensor.get_device(), tensor.dtype, name))
class LinearUpSampling(nn.Module):
    """ Interpolate to upsample.

    A 1x1x1 convolution changes the number of channels, the result is
    resized by interpolation and, optionally, concatenated with a skip
    tensor and merged by a second 1x1x1 convolution.
    """
    # Interpolation modes for which 'align_corners' may be specified.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, in_channels, out_channels, scale_factor=2,
                 mode="trilinear", align_corners=True):
        """ Init class.

        Parameters
        ----------
        in_channels: int
            number of channels of the input tensor. It is also the number of
            channels expected by the merging convolution, i.e. after the
            concatenation with the skip tensor.
        out_channels: int
            number of channels of the output tensor.
        scale_factor: int, default 2
            the interpolation scale factor used when no target is given.
        mode: str, default 'trilinear'
            the interpolation mode.
        align_corners: bool, default True
            the interpolation 'align_corners' option. It is ignored for the
            modes that do not support it (e.g. 'nearest' or 'area').
        """
        super(LinearUpSampling, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.conv1 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1)
        # Only used when a skip tensor is concatenated in 'forward'.
        self.conv2 = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1)

    def forward(self, x, skipx=None, cat=True):
        """ Upsample the input tensor.

        Parameters
        ----------
        x: torch.Tensor
            the input tensor.
        skipx: torch.Tensor or sequence of int, default None
            a skip tensor whose spatial shape (dimensions 2 and above) is
            the interpolation target, or directly the target spatial shape.
            If None, 'scale_factor' is used instead.
        cat: bool, default True
            if set and a skip tensor is given, concatenate it with the
            upsampled features along dimension 1 and merge them with the
            second convolution.

        Returns
        -------
        out: torch.Tensor
            the upsampled tensor.

        Raises
        ------
        TypeError
            if a concatenation is requested while 'skipx' is not a tensor.
        """
        out = self.conv1(x)
        # 'interpolate' rejects an explicit 'align_corners' for the
        # non-interpolating modes, so only forward it where it is allowed.
        if self.mode in self._ALIGNABLE_MODES:
            align_corners = self.align_corners
        else:
            align_corners = None
        if skipx is None:
            return nn.functional.interpolate(
                out,
                scale_factor=self.scale_factor,
                mode=self.mode,
                align_corners=align_corners)
        is_tensor = isinstance(skipx, torch.Tensor)
        shape = skipx.shape[2:] if is_tensor else skipx
        out = nn.functional.interpolate(
            out,
            size=shape,
            mode=self.mode,
            align_corners=align_corners)
        if cat:
            if not is_tensor:
                raise TypeError(
                    "'skipx' must be a tensor to be concatenated: use "
                    "'cat=False' when only a target shape is given.")
            out = torch.cat((out, skipx), 1)
            out = self.conv2(out)
        return out
class DecoderBlock(EncoderBlock):
    """ Decoder block: same layers as the encoder block, every constructor
    argument being forwarded unchanged to the parent class.
    """
    def __init__(self, in_channels, out_channels, stride=1, kernel_size=3,
                 padding=1, num_groups=8, activation="relu",
                 normalization="group_normalization"):
        """ Init class: see 'EncoderBlock' for the meaning of the parameters.
        """
        # Gather the optional settings first, then delegate everything to
        # the encoder block constructor.
        settings = dict(
            stride=stride,
            kernel_size=kernel_size,
            padding=padding,
            num_groups=num_groups,
            activation=activation,
            normalization=normalization)
        super(DecoderBlock, self).__init__(
            in_channels=in_channels, out_channels=out_channels, **settings)
class OutputTransition(nn.Module):
    """ Decoder output layer: output the prediction of the segmentation.

    A single 1x1x1 convolution maps the decoder features to the requested
    number of output channels.
    """
    def __init__(self, in_channels, out_channels):
        """ Init class.

        Parameters
        ----------
        in_channels: int
            number of channels of the input tensor.
        out_channels: int
            number of channels of the output tensor.
        """
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """ Apply the 1x1x1 convolution.

        Without this method the module cannot be called: 'nn.Module' raises
        'NotImplementedError' when 'forward' is not defined.
        """
        # NOTE(review): no activation layer is declared in this class, so
        # the raw convolution output is returned - confirm whether the
        # training loss expects an activation (e.g. a sigmoid) here.
        return self.conv1(x)
class VDResampling(nn.Module):
    """ Variational Auto-Encoder Resampling block.

    The input features are reduced by a strided convolution and a dense
    layer to a vector that parametrizes a Gaussian distribution. A sample
    drawn from this distribution is expanded by a second dense layer,
    reshaped as a feature map and resized to the spatial shape of the input.
    """
    def __init__(self, in_channels=256, out_channels=256,
                 dense_features=(10, 12, 8), stride=2, kernel_size=3,
                 padding=1, activation="relu",
                 normalization="group_normalization"):
        """ Init class.

        Parameters
        ----------
        in_channels: int, default 256
            number of channels of the input tensor.
        out_channels: int, default 256
            number of channels of the output tensor.
        dense_features: 3-uplet, default (10, 12, 8)
            the spatial shape of the strided convolution output: it fixes
            the size of the dense layers and must therefore match the data.
        stride: int, default 2
            the convolution stride.
        kernel_size: int, default 3
            the convolution kernel size.
        padding: int, default 1
            the convolution padding.
        activation: str, default 'relu'
            the activation function: 'relu' or 'elu'.
        normalization: str, default 'group_normalization'
            the normalization function.

        Raises
        ------
        ValueError
            if 'activation' or 'normalization' is not supported.
        """
        super(VDResampling, self).__init__()
        # Fail at construction time: an unsupported name would otherwise
        # leave the layers undefined and only break later in 'forward'.
        if normalization != "group_normalization":
            raise ValueError(
                "'{}' is not a valid normalization. Only "
                "'group_normalization' is allowed.".format(normalization))
        if activation not in ("relu", "elu"):
            raise ValueError(
                "'{}' is not a valid activation. Only 'relu' "
                "and 'elu' are allowed.".format(activation))
        self.mid_chans = in_channels // 2
        self.dense_features = dense_features
        n_voxels = dense_features[0] * dense_features[1] * dense_features[2]
        self.gn1 = nn.GroupNorm(
            num_groups=8,
            num_channels=in_channels)
        actv_class = nn.ReLU if activation == "relu" else nn.ELU
        self.actv1 = actv_class(inplace=True)
        self.actv2 = actv_class(inplace=True)
        self.conv1 = nn.Conv3d(
            in_channels=in_channels, out_channels=16, kernel_size=kernel_size,
            stride=stride, padding=padding)
        self.dense1 = nn.Linear(
            in_features=16 * n_voxels,
            out_features=in_channels)
        # The latent sample feeds 'dense2', so its size has to be
        # 'mid_chans' whatever the number of input channels.
        self.vdraw = VDraw(
            in_channels=in_channels, out_channels=self.mid_chans)
        self.dense2 = nn.Linear(
            in_features=self.mid_chans,
            out_features=self.mid_chans * n_voxels)
        self.up0 = LinearUpSampling(self.mid_chans, out_channels)

    def forward(self, x):
        """ Encode the input as a distribution, sample it and decode.

        Parameters
        ----------
        x: torch.Tensor
            the input tensor.

        Returns
        -------
        out: torch.Tensor
            the decoded sample, resized to the spatial shape of 'x'.
        kwargs: dict
            the drawn sample ('z') and the distribution it comes from ('q').
        """
        logger.debug("Resampling tensor: {0}".format(x.shape))
        n_samples = x.shape[0]

        # Variational down: features -> distribution parameters
        out = self.gn1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        logger.debug("Resampling VD 1.1: {0}".format(out.shape))
        out = out.view(-1, self.num_flat_features(out))
        logger.debug("Resampling VD 1.2: {0}".format(out.shape))
        out_vd = self.dense1(out)
        logger.debug("Resampling VD 2: {0}".format(out_vd.shape))

        # Draw a sample; 'rsample' keeps the sampling differentiable
        # (reparametrization trick).
        q = self.vdraw(out_vd)
        z = q.rsample()
        logger.debug("Resampling VDraw: {0}".format(z.shape))

        # Variational up: sample -> feature map
        out = self.dense2(z)
        out = self.actv2(out)
        logger.debug("Resampling VU 1.1: {0}".format(out.shape))
        out = out.view((n_samples, self.mid_chans, self.dense_features[0],
                        self.dense_features[1], self.dense_features[2]))
        logger.debug("Resampling VU 1.2: {0}".format(out.shape))
        # 'x' is only used as the interpolation target (no concatenation).
        out = self.up0(out, x, cat=False)
        logger.debug("Resampling VU 2: {0}".format(out.shape))
        return out, {"z": z, "q": q}

    def num_flat_features(self, x):
        """ Return the number of elements per sample, i.e. the product of
        all the dimensions of 'x' but the first one.
        """
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
class VDraw(nn.Module):
    """ Generate a Gaussian distribution whose mean and standard deviation
    are predicted from the input by two linear layers (128-d each by
    default).
    """
    def __init__(self, in_channels=256, out_channels=128):
        """ Init class.

        Parameters
        ----------
        in_channels: int, default 256
            the size of the input vectors.
        out_channels: int, default 128
            the size of the mean and of the standard deviation.
        """
        super(VDraw, self).__init__()
        self.w_mu = nn.Linear(
            in_features=in_channels,
            out_features=out_channels)
        self.w_logvar = nn.Linear(
            in_features=in_channels,
            out_features=out_channels)

    def forward(self, x):
        """ Return the Normal distribution parametrized by the input.

        The second linear layer predicts the log-variance, which is
        converted to a standard deviation.
        """
        z_mu = self.w_mu(x)
        z_logvar = self.w_logvar(x)
        # std = sqrt(exp(logvar)) = exp(logvar / 2): halving before the
        # exponential avoids overflowing on the intermediate variance.
        return Normal(loc=z_mu, scale=(0.5 * z_logvar).exp())

    def reparametrization(self, q):
        """ Draw a differentiable sample from a distribution.

        Parameters
        ----------
        q: torch.distributions.Distribution
            a distribution supporting reparametrized sampling, such as the
            one returned by 'forward'.

        Returns
        -------
        z: torch.Tensor
            the sample, through which gradients can flow back to the
            distribution parameters.
        """
        return q.rsample()
class VDecoderBlock(nn.Module):
    """ Variational Decoder block: an upsampling followed by a decoder
    block, without skip connection.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 activation="relu", normalization="group_normalization",
                 mode="trilinear"):
        """ Init class.

        Parameters
        ----------
        in_channels: int
            number of channels of the input tensor.
        out_channels: int
            number of channels of the output tensor.
        kernel_size: int, default 3
            the kernel size of the decoder block convolutions.
        activation: str, default 'relu'
            the activation function.
        normalization: str, default 'group_normalization'
            the normalization function.
        mode: str, default 'trilinear'
            the interpolation mode.
        """
        super(VDecoderBlock, self).__init__()
        self.up = LinearUpSampling(
            in_channels=in_channels, out_channels=out_channels, mode=mode)
        # The decoder block works at constant width: the channel reduction
        # is done by the upsampling layer.
        block_settings = dict(
            kernel_size=kernel_size,
            activation=activation,
            normalization=normalization)
        self.de_block = DecoderBlock(
            in_channels=out_channels, out_channels=out_channels,
            **block_settings)

    def forward(self, x, shape=None):
        """ Upsample the input, to 'shape' when it is given or by the
        default scale factor otherwise, then apply the decoder block.
        """
        if shape is not None:
            # The shape is only an interpolation target: nothing to
            # concatenate.
            upsampled = self.up(x, skipx=shape, cat=False)
        else:
            upsampled = self.up(x)
        return self.de_block(upsampled)
class VAE(nn.Module):
    """ Variational Auto-Encoder: to group the features extracted by Encoder.

    The encoder features are resampled through a latent Gaussian
    distribution and decoded by three upsampling blocks and a final 1x1x1
    convolution.
    """
    def __init__(self, shapes, in_channels=256, out_channels=4, kernel_size=3,
                 activation="relu", normalization="group_normalization",
                 mode="trilinear"):
        """ Init class.

        Parameters
        ----------
        shapes: list of 3-uplet
            the spatial shapes of the successive levels: items 0, 1 and 2
            are the targets of the last, second and first decoding blocks
            respectively, and the last item sizes the dense layers of the
            resampling block.
        in_channels: int, default 256
            number of channels of the input tensor.
        out_channels: int, default 4
            number of channels of the output tensor.
        kernel_size: int, default 3
            the convolutions kernel size.
        activation: str, default 'relu'
            the activation function.
        normalization: str, default 'group_normalization'
            the normalization function.
        mode: str, default 'trilinear'
            the interpolation mode of the decoding blocks.
        """
        super(VAE, self).__init__()
        self.shapes = shapes
        # The requested activation and normalization are forwarded to the
        # resampling block as well, so that the whole branch honours them.
        self.vd_resample = VDResampling(
            in_channels=in_channels,
            out_channels=in_channels,
            dense_features=shapes[-1],
            stride=2,
            kernel_size=kernel_size,
            activation=activation,
            normalization=normalization)
        # Each decoding block halves the number of channels.
        self.vd_block2 = VDecoderBlock(
            in_channels=in_channels,
            out_channels=in_channels//2,
            kernel_size=kernel_size,
            activation=activation,
            normalization=normalization,
            mode=mode)
        self.vd_block1 = VDecoderBlock(
            in_channels=in_channels//2,
            out_channels=in_channels//4,
            kernel_size=kernel_size,
            activation=activation,
            normalization=normalization,
            mode=mode)
        self.vd_block0 = VDecoderBlock(
            in_channels=in_channels//4,
            out_channels=in_channels//8,
            kernel_size=kernel_size,
            activation=activation,
            normalization=normalization,
            mode=mode)
        self.vd_end = nn.Conv3d(
            in_channels=in_channels//8,
            out_channels=out_channels,
            kernel_size=1)

    def forward(self, x):
        """ Resample and decode the input features.

        Parameters
        ----------
        x: torch.Tensor
            the input tensor.

        Returns
        -------
        out: torch.Tensor
            the output of the final convolution.
        resample_kwargs: dict
            the dictionary returned by the resampling block.
        """
        logger.debug("Variational decoder tensor: {0}".format(x.shape))
        out, resample_kwargs = self.vd_resample(x)
        logger.debug("Variational decoder resampling: {0}".format(
            out.shape))
        # Decode level by level, each block targeting the stored shape of
        # the next finer level.
        out = self.vd_block2(out, self.shapes[2])
        logger.debug("Variational decoder 1: {0}".format(out.shape))
        out = self.vd_block1(out, self.shapes[1])
        logger.debug("Variational decoder 2: {0}".format(out.shape))
        out = self.vd_block0(out, self.shapes[0])
        logger.debug("Variational decoder 3: {0}".format(out.shape))
        out = self.vd_end(out)
        logger.debug("Variational decoder final conv: {0}".format(out.shape))
        return out, resample_kwargs
Follow us
© 2019, pynet developers .
Inspired by AZMIND template.
Inspired by AZMIND template.