Helper Module for Deep Learning.
Source code for pynet.models.braingengan
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
3D MRI Brain Generation with Generative Adversarial Networks (BGGAN) with
Variational Auto Encoder (VAE).
"""
# Imports
import logging
import collections
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as func
from pynet.utils import Networks
# Global parameters
logger = logging.getLogger("pynet")
@Networks.register
class BGDiscriminator(nn.Module):
    """ This is the discriminator part of the BGGAN.

    Four strided 4x4x4 convolutions (stride 2, padding 1) halve the spatial
    size of the input at each step; a last convolution whose kernel spans
    the whole remaining volume collapses it to a single voxel per output
    channel.
    """
    def __init__(self, in_shape, in_channels=1, out_channels=1,
                 start_filts=64, with_logit=True):
        """ Init class.

        Parameters
        ----------
        in_shape: uplet
            the input tensor data shape (X, Y, Z).
        in_channels: int, default 1
            number of channels in the input tensor.
        out_channels: int, default 1
            number of channels in the output tensor.
        start_filts: int, default 64
            number of convolutional filters for the first conv.
        with_logit: bool, default True
            if True, apply a sigmoid to the scores and flatten them to
            (batch, -1); otherwise return the raw 5D convolution output.
        """
        super(BGDiscriminator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.start_filts = start_filts
        self.with_logit = with_logit
        self.in_shape = in_shape
        # Spatial shapes after each of the 4 stride-2 convolutions; the last
        # one is used as the kernel size of the final convolution.
        self.shapes = _downsample_shape(
            self.in_shape, nb_iterations=4, scale_factor=2)
        self.conv1 = nn.Conv3d(
            self.in_channels, self.start_filts, kernel_size=4, stride=2,
            padding=1)
        self.conv2 = nn.Conv3d(
            self.start_filts, self.start_filts * 2, kernel_size=4, stride=2,
            padding=1)
        self.bn2 = nn.BatchNorm3d(self.start_filts * 2)
        self.conv3 = nn.Conv3d(
            self.start_filts * 2, self.start_filts * 4, kernel_size=4,
            stride=2, padding=1)
        self.bn3 = nn.BatchNorm3d(self.start_filts * 4)
        self.conv4 = nn.Conv3d(
            self.start_filts * 4, self.start_filts * 8, kernel_size=4,
            stride=2, padding=1)
        self.bn4 = nn.BatchNorm3d(self.start_filts * 8)
        self.conv5 = nn.Conv3d(
            self.start_filts * 8, self.out_channels,
            kernel_size=self.shapes[-1], stride=1, padding=0)

    def forward(self, x):
        """ Score a batch of volumes.

        Parameters
        ----------
        x: Tensor
            the input volumes, (batch, in_channels, X, Y, Z).

        Returns
        -------
        output: Tensor
            sigmoid scores flattened to (batch, -1) when `with_logit` is
            True, the raw conv5 output otherwise.
        """
        logger.debug("BGGAN Discriminator...")
        self.debug("input", x)
        h1 = func.leaky_relu(self.conv1(x), negative_slope=0.2)
        self.debug("conv1", h1)
        h2 = func.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2)
        self.debug("conv2", h2)
        h3 = func.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2)
        self.debug("conv3", h3)
        h4 = func.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2)
        self.debug("conv4", h4)
        h5 = self.conv5(h4)
        self.debug("conv5", h5)
        if self.with_logit:
            output = torch.sigmoid(h5.view(h5.size(0), -1))
            self.debug("output", output)
        else:
            output = h5
        logger.debug("Done.")
        return output

    def debug(self, name, tensor):
        """ Log the shape, device index and dtype of an intermediate tensor.

        The message is only built when the module logger emits DEBUG
        records, so the check costs next to nothing during normal runs.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug(" {3}: {0} - {1} - {2}".format(
            tensor.shape, tensor.get_device(), tensor.dtype, name))
@Networks.register
class BGEncoder(nn.Module):
    """ This is the encoder part of the BGGAN.

    Four strided 4x4x4 convolutions reduce the input volume, then two
    fully connected heads predict the mean and log-variance of the latent
    distribution, from which a code is sampled with the reparametrization
    trick.
    """
    def __init__(self, in_shape, in_channels=1, start_filts=64,
                 latent_dim=1000):
        """ Init class.

        Parameters
        ----------
        in_shape: uplet
            the input tensor data shape (X, Y, Z).
        in_channels: int, default 1
            number of channels in the input tensor.
        start_filts: int, default 64
            number of convolutional filters for the first conv.
        latent_dim: int, default 1000
            the latent variable sizes.
        """
        super(BGEncoder, self).__init__()
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.latent_dim = latent_dim
        self.in_shape = in_shape
        self.shapes = _downsample_shape(
            self.in_shape, nb_iterations=4, scale_factor=2)
        # Number of voxels left after the 4 downsampling steps; stored as a
        # plain int so the layer sizes below are Python ints.
        self.dense_features = int(np.prod(self.shapes[-1]))
        logger.debug("BGGAN Encoder shapes: {0}".format(self.shapes))
        self.conv1 = nn.Conv3d(
            self.in_channels, self.start_filts, kernel_size=4, stride=2,
            padding=1)
        self.conv2 = nn.Conv3d(
            self.start_filts, self.start_filts * 2, kernel_size=4, stride=2,
            padding=1)
        self.bn2 = nn.BatchNorm3d(self.start_filts * 2)
        self.conv3 = nn.Conv3d(
            self.start_filts * 2, self.start_filts * 4, kernel_size=4,
            stride=2, padding=1)
        self.bn3 = nn.BatchNorm3d(self.start_filts * 4)
        self.conv4 = nn.Conv3d(
            self.start_filts * 4, self.start_filts * 8, kernel_size=4,
            stride=2, padding=1)
        self.bn4 = nn.BatchNorm3d(self.start_filts * 8)
        self.mean = nn.Sequential(
            nn.Linear(self.start_filts * 8 * self.dense_features, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, self.latent_dim))
        self.logvar = nn.Sequential(
            nn.Linear(self.start_filts * 8 * self.dense_features, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, self.latent_dim))

    def forward(self, x):
        """ Encode a batch of volumes.

        Parameters
        ----------
        x: Tensor
            the input volumes, (batch, in_channels, X, Y, Z).

        Returns
        -------
        mean: Tensor
            the predicted latent means, (batch, latent_dim).
        logvar: Tensor
            the predicted latent log-variances, (batch, latent_dim).
        reparametrized_noise: Tensor
            a sampled code mean + exp(0.5 * logvar) * eps with
            eps ~ N(0, 1), (batch, latent_dim).
        """
        logger.debug("BGGAN Encoder...")
        batch_size = x.size(0)
        logger.debug(" batch_size: {0}".format(batch_size))
        self.debug("input", x)
        h1 = func.leaky_relu(self.conv1(x), negative_slope=0.2)
        self.debug("conv1", h1)
        h2 = func.leaky_relu(self.bn2(self.conv2(h1)), negative_slope=0.2)
        self.debug("conv2", h2)
        h3 = func.leaky_relu(self.bn3(self.conv3(h2)), negative_slope=0.2)
        self.debug("conv3", h3)
        h4 = func.leaky_relu(self.bn4(self.conv4(h3)), negative_slope=0.2)
        self.debug("conv4", h4)
        flat = h4.view(batch_size, -1)
        mean = self.mean(flat)
        self.debug("mean", mean)
        logvar = self.logvar(flat)
        self.debug("logvar", logvar)
        std = logvar.mul(0.5).exp_()
        # Sample eps directly with the dtype/device of std: avoids a CPU
        # allocation followed by a host-to-device copy, and keeps the
        # result in the model precision (e.g. half).
        eps = torch.randn_like(std)
        reparametrized_noise = mean + std * eps
        self.debug("reparametrization", reparametrized_noise)
        logger.debug("Done.")
        return mean, logvar, reparametrized_noise

    def debug(self, name, tensor):
        """ Log the shape, device index and dtype of an intermediate tensor.

        The message is only built when the module logger emits DEBUG
        records, so the check costs next to nothing during normal runs.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug(" {3}: {0} - {1} - {2}".format(
            tensor.shape, tensor.get_device(), tensor.dtype, name))
@Networks.register
class BGCodeDiscriminator(nn.Module):
    """ This is the code discriminator part of the BGGAN.

    A three-layer perceptron that scores latent codes.
    """
    def __init__(self, out_channels=1, code_size=1000, n_units=4096):
        """ Init class.

        Parameters
        ----------
        out_channels: int, default 1
            number of channels in the output tensor.
        code_size: int, default 1000
            the code size.
        n_units: int, default 4096
            the number of hidden units.
        """
        super(BGCodeDiscriminator, self).__init__()
        self.out_channels = out_channels
        self.code_size = code_size
        self.n_units = n_units
        self.layer1 = nn.Sequential(
            nn.Linear(self.code_size, self.n_units),
            nn.BatchNorm1d(self.n_units),
            nn.LeakyReLU(0.2, inplace=True))
        self.layer2 = nn.Sequential(
            nn.Linear(self.n_units, self.n_units),
            nn.BatchNorm1d(self.n_units),
            nn.LeakyReLU(0.2, inplace=True))
        self.layer3 = nn.Linear(self.n_units, self.out_channels)

    def forward(self, x):
        """ Score a batch of latent codes.

        Parameters
        ----------
        x: Tensor
            the codes, (batch, code_size).

        Returns
        -------
        output: Tensor
            the raw (no sigmoid) scores, (batch, out_channels).
        """
        logger.debug("BGGAN Code Discriminator...")
        self.debug("input", x)
        h1 = self.layer1(x)
        self.debug("layer1", h1)
        h2 = self.layer2(h1)
        self.debug("layer2", h2)
        output = self.layer3(h2)
        self.debug("layer3", output)
        logger.debug("Done.")
        return output

    def debug(self, name, tensor):
        """ Log the shape, device index and dtype of an intermediate tensor.

        The message is only built when the module logger emits DEBUG
        records, so the check costs next to nothing during normal runs.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug(" {3}: {0} - {1} - {2}".format(
            tensor.shape, tensor.get_device(), tensor.dtype, name))
@Networks.register
class BGGenerator(nn.Module):
    """ This is the generator part of the BGGAN.

    A latent code is mapped to a coarse feature volume, which is then
    upsampled four times by interpolation, each step being followed by a
    3x3x3 convolution, until the requested `in_shape` is reached.
    """
    # Interpolation modes that accept the 'align_corners' option. Passing
    # that option with any other mode (e.g. 'nearest', 'area') makes
    # torch.nn.functional.interpolate raise a ValueError.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, in_shape, out_channels=1, start_filts=64,
                 latent_dim=1000, mode="trilinear", with_code=False):
        """ Init class.

        Parameters
        ----------
        in_shape: uplet
            the spatial shape (X, Y, Z) of the generated volume.
        out_channels: int, default 1
            number of channels in the output tensor.
        start_filts: int, default 64
            number of convolutional filters for the last hidden conv.
        latent_dim: int, default 1000
            the latent variable sizes.
        mode: str, default 'trilinear'
            the interpolation mode used by the upsampling steps.
        with_code: bool, default False
            change the architecture if code discriminator is used: the
            code is expanded by a transposed convolution to a 4x4x4
            volume instead of going through a fully connected layer.
        """
        super(BGGenerator, self).__init__()
        self.out_channels = out_channels
        self.start_filts = start_filts
        self.latent_dim = latent_dim
        self.in_shape = in_shape
        self.mode = mode
        self.with_code = with_code
        # shapes[0] is in_shape, shapes[-1] the coarsest resolution.
        self.shapes = _downsample_shape(
            self.in_shape, nb_iterations=4, scale_factor=2)
        self.dense_features = int(np.prod(self.shapes[-1]))
        logger.debug("BGGAN Generator shapes: {0}".format(self.shapes))
        if self.with_code:
            self.tp_conv1 = nn.ConvTranspose3d(
                self.latent_dim, self.start_filts * 8, kernel_size=4,
                stride=1, padding=0, bias=False)
        else:
            self.fc = nn.Linear(
                self.latent_dim, self.start_filts * 8 * self.dense_features)
        self.bn1 = nn.BatchNorm3d(self.start_filts * 8)
        # NOTE: despite their 'tp_' prefix these are regular convolutions;
        # the names are kept so existing checkpoints still load.
        self.tp_conv2 = nn.Conv3d(
            self.start_filts * 8, self.start_filts * 4, kernel_size=3,
            stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(self.start_filts * 4)
        self.tp_conv3 = nn.Conv3d(
            self.start_filts * 4, self.start_filts * 2, kernel_size=3,
            stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(self.start_filts * 2)
        self.tp_conv4 = nn.Conv3d(
            self.start_filts * 2, self.start_filts, kernel_size=3, stride=1,
            padding=1, bias=False)
        self.bn4 = nn.BatchNorm3d(self.start_filts)
        self.tp_conv5 = nn.Conv3d(
            self.start_filts, self.out_channels, kernel_size=3, stride=1,
            padding=1, bias=False)

    def forward(self, noise):
        """ Generate a batch of volumes from latent codes.

        Parameters
        ----------
        noise: Tensor
            the latent codes, reshaped internally to (-1, latent_dim).

        Returns
        -------
        h: Tensor
            the generated volumes in [-1, 1] (tanh),
            (batch, out_channels, *in_shape).
        """
        logger.debug("BGGAN Generator...")
        self.debug("input", noise)
        if self.with_code:
            noise = noise.view(-1, self.latent_dim, 1, 1, 1)
            self.debug("view", noise)
            h = self.tp_conv1(noise)
            self.debug("tp_conv1", h)
        else:
            noise = noise.view(-1, self.latent_dim)
            self.debug("view", noise)
            h = self.fc(noise)
            self.debug("dense", h)
            h = h.view(-1, self.start_filts * 8, *self.shapes[-1])
            self.debug("view", h)
        h = func.relu(self.bn1(h))
        h = self._upsample(h, self.shapes[-2])
        h = self.tp_conv2(h)
        h = func.relu(self.bn2(h))
        self.debug("tp_conv2", h)
        h = self._upsample(h, self.shapes[-3])
        h = self.tp_conv3(h)
        h = func.relu(self.bn3(h))
        self.debug("tp_conv3", h)
        h = self._upsample(h, self.shapes[-4])
        h = self.tp_conv4(h)
        h = func.relu(self.bn4(h))
        self.debug("tp_conv4", h)
        h = self._upsample(h, self.shapes[-5])
        h = self.tp_conv5(h)
        self.debug("tp_conv5", h)
        h = torch.tanh(h)
        self.debug("output", h)
        logger.debug("Done.")
        return h

    def _upsample(self, tensor, size):
        """ Resize a feature volume to `size` with the configured mode.

        'align_corners=False' is only forwarded for modes that support it,
        so non-linear modes such as 'nearest' also work.
        """
        align_corners = (
            False if self.mode in self._ALIGNABLE_MODES else None)
        return func.interpolate(
            tensor, size=size, mode=self.mode, align_corners=align_corners)

    def debug(self, name, tensor):
        """ Log the shape, device index and dtype of an intermediate tensor.

        The message is only built when the module logger emits DEBUG
        records, so the check costs next to nothing during normal runs.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug(" {3}: {0} - {1} - {2}".format(
            tensor.shape, tensor.get_device(), tensor.dtype, name))
def _downsample_shape(shape, nb_iterations=1, scale_factor=2):
shape = np.asarray(shape)
all_shapes = [shape.astype(int).tolist()]
for idx in range(nb_iterations):
shape = np.floor(shape / scale_factor)
all_shapes.append(shape.astype(int).tolist())
return all_shapes
Follow us
© 2019, pynet developers.
Inspired by AZMIND template.