Helper Module for Deep Learning.
Source code for pynet.models.vtnet
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Volume Tweening Network (VTN) and Affine and Dense Deformable Network (ADDNet)
for Unsupervised medical Image Registration.
"""
# Imports
import logging
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from .voxelmorphnet import SpatialTransformer
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks
from pynet.utils import Regularizers
# Global parameters
logger = logging.getLogger("pynet")
[docs]@Networks.register
@DeepLearningDecorator(family="register")
class VTNet(nn.Module):
"""
VTNet.
Volume Tweening Network(VTN) consists of several cascaded registration
subnetworks, after each of which the moving image is warped. The
unsupervised training of network parameters is guided by the dissimilarity
between the fixed image and each of the warped images, with the
regularization losses on the flows predicted by the networks.
It follows an encoder-decoder architecture.
Reference: https://arxiv.org/pdf/1902.05020.
Code: https://github.com/microsoft/Recursive-Cascaded-Networks.
"""
[docs] def __init__(self, input_shape, in_channels, kernel_size=3, padding=1,
flow_multiplier=1., nb_channels=16):
""" Init class.
Parameters
----------
input_shape: uplet
the tensor data shape (X, Y, Z).
in_channels: int
number of channels in the input tensor.
kernel_size: int, default 3
the convolution kernels size (odd number).
padding: int, default 1
the padding size, recommended (kernel_size - 1) / 2
flow_multiplier: foat, default 1
weight the flow field by this factor.
nb_channels: int, default 16
the number of channels after the first convolution.
"""
# Inheritance
nn.Module.__init__(self)
# Class parameters
self.input_shape = input_shape
self.in_channels = in_channels
self.kernel_size = kernel_size
self.padding = padding
self.flow_multiplier = flow_multiplier
self.shapes = self._downsample_shape(
input_shape, nb_iterations=6, scale_factor=2)
self.nb_channels = nb_channels
# Use strided 3D convolution to progressively downsample the image,
# and then use deconvolution (transposed convolution) to recover
# spatial resolution. As suggested in U-Net, skip connections between
# the convolutional layers and the deconvolutional layers are added
# to help refining dense prediction. The network will output the dense
# flow field, a volume feature map with 3 channels (X, Y, Z
# displacements) of the same size as the input.
out_channels = nb_channels
for idx in range(1, 3):
ops = self._conv(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=True,
negative_slope=0.1)
setattr(self, "down{0}".format(idx), ops)
in_channels = out_channels
out_channels *= 2
for idx in range(3, 7):
ops = self._double_conv(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=True,
negative_slope=0.1)
setattr(self, "down{0}".format(idx), ops)
in_channels = out_channels
out_channels *= 2
out_channels = in_channels // 2
for idx in range(5, 0, -1):
if idx < 5:
in_channels = in_channels * 2 + 3
pred_ops = self._prediction(in_channels=in_channels)
ops = self._upconv(
in_channels=in_channels, out_channels=out_channels,
kernel_size=4, stride=2, padding=2, groups=1,
negative_slope=0.1)
setattr(self, "pred{0}".format(idx + 1), pred_ops)
setattr(self, "up{0}".format(idx), ops)
in_channels = out_channels
out_channels = out_channels // 2
in_channels = in_channels * 2 + 3
self.pred1 = nn.ConvTranspose3d(
in_channels=in_channels, out_channels=3, kernel_size=4,
stride=2, padding=1, groups=1)
# Finally warp the moving image.
self.spatial_transform = SpatialTransformer(input_shape)
# Init weights
@torch.no_grad()
def weights_init(module):
if isinstance(module, nn.Conv3d):
logger.debug("Init weights of {0}...".format(module))
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
self.apply(weights_init)
def _conv(self, in_channels, out_channels, kernel_size, stride=1,
padding=1, bias=True, negative_slope=1e-2):
ops = nn.Sequential(collections.OrderedDict([
("conv", nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, bias=bias)),
("act", nn.LeakyReLU(negative_slope=negative_slope))
]))
return ops
def _double_conv(self, in_channels, out_channels, kernel_size, stride=2,
padding=1, bias=True, negative_slope=1e-2):
ops = nn.Sequential(collections.OrderedDict([
("conv1", nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, bias=bias)),
("act1", nn.LeakyReLU(negative_slope=negative_slope)),
("conv2", nn.Conv3d(out_channels, out_channels, kernel_size,
stride=1, padding=padding, bias=bias)),
("act2", nn.LeakyReLU(negative_slope=negative_slope))
]))
return ops
def _upconv(self, in_channels, out_channels, kernel_size, stride=2,
padding=1, groups=1, negative_slope=1e-2):
ops = nn.Sequential(collections.OrderedDict([
("convt", nn.ConvTranspose3d(
in_channels, out_channels, kernel_size, stride=stride,
padding=1, groups=groups)),
("act", nn.LeakyReLU(negative_slope=negative_slope))
]))
return ops
def _prediction(self, in_channels):
ops = nn.Sequential(collections.OrderedDict([
("conv", nn.Conv3d(
in_channels, out_channels=3, kernel_size=3,
stride=1, padding=1, bias=True)),
("convt", nn.ConvTranspose3d(
in_channels=3, out_channels=3, kernel_size=4, stride=2,
padding=1, groups=1))
]))
return ops
def _downsample_shape(self, shape, nb_iterations=1, scale_factor=2):
shape = np.asarray(shape)
all_shapes = [shape.astype(int).tolist()]
for idx in range(nb_iterations):
shape = np.ceil(shape / scale_factor)
all_shapes.append(shape.astype(int).tolist())
return all_shapes
[docs] def forward(self, x):
""" Forward method.
Parameters
----------
x: Tensor
concatenated moving and fixed images (batch, 2 * channels, X, Y, Z)
"""
logger.debug("VTNet...")
nb_channels = x.shape[1] // 2
device = x.get_device()
logger.debug(" nb_channels: {0}".format(nb_channels))
self.debug("input", x)
moving = x[:, :nb_channels]
self.debug("moving", moving)
skipx = []
for idx in range(1, 7):
logger.debug("Applying down{0}...".format(idx))
self.debug("input", x)
layer = getattr(self, "down{0}".format(idx))
logger.debug(" filter: {0}".format(layer))
x = layer(x)
skipx.append(x)
self.debug("output", x)
logger.debug("Done.")
for idx in range(5, 0, -1):
logger.debug("Applying up{0}...".format(idx))
self.debug("input", x)
layer = getattr(self, "up{0}".format(idx))
pred_layer = getattr(self, "pred{0}".format(idx + 1))
logger.debug(" filter: {0}".format(layer))
logger.debug(" pred filter: {0}".format(pred_layer))
flow_pred = pred_layer(x)
self.debug("flow prediction", flow_pred)
x = layer(x)
self.debug("layer output", x)
self.debug("skip connexion", skipx[idx - 1])
x = torch.cat((skipx[idx - 1], x, flow_pred), dim=1)
self.debug("output", x)
logger.debug("Done.")
logger.debug("Estimating flow field...")
logger.debug(" pred filter: {0}".format(self.pred1))
flow = self.pred1(x)
self.debug("flow", flow)
logger.debug("Done.")
logger.debug("Applying warp...")
self.debug("moving", moving)
warp, _ = self.spatial_transform(moving, flow)
self.debug("warp", warp)
logger.debug("Done.")
logger.debug("Done.")
return warp, {"flow": flow * self.flow_multiplier}
[docs] def debug(self, name, tensor):
""" Print debug message.
Parameters
----------
name: str
the tensor name in the displayed message.
tensor: Tensor
a pytorch tensor.
"""
logger.debug(" {3}: {0} - {1} - {2}".format(
tensor.shape, tensor.get_device(), tensor.dtype, name))
[docs]@Networks.register
@DeepLearningDecorator(family="register")
class ADDNet(nn.Module):
"""
ADDNet.
Affine and Dense Deformable Network (ADDNet): affine registration
subnetwork predicts a set of affine parameters, after which a flow field
can be generated for warping.
Reference: https://arxiv.org/pdf/1902.05020.
Code: https://github.com/microsoft/Recursive-Cascaded-Networks.
"""
[docs] def __init__(self, input_shape, in_channels, kernel_size=3, padding=1,
flow_multiplier=1.):
""" Init class.
Parameters
----------
input_shape: uplet
the tensor data shape (X, Y, Z).
in_channels: int
number of channels in the input tensor.
kernel_size: int, default 3
the convolution kernels size (odd number).
padding: int, default 1
the padding size, recommended (kernel_size - 1) / 2
flow_multiplier: foat, default 1
weight the flow field by this factor.
"""
# Inheritance
nn.Module.__init__(self)
# Class parameters
self.input_shape = input_shape
self.in_channels = in_channels
self.kernel_size = kernel_size
self.padding = padding
self.flow_multiplier = flow_multiplier
self.shapes = self._downsample_shape(
input_shape, nb_iterations=6, scale_factor=2)
self.dense_features = np.prod(self.shapes[-1])
# The input is downsampled by strided 3D convolutions, and finally a
# fully-connected layer is applied to produce 12 numeric parameters
# as output, which represents a 3×3 transform matrix.
out_channels = 16
for idx in range(1, 3):
ops = self._conv(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=True,
negative_slope=0.1)
setattr(self, "layer{0}".format(idx), ops)
in_channels = out_channels
out_channels *= 2
for idx in range(3, 7):
ops = self._double_conv(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=2, padding=1, bias=True,
negative_slope=0.1)
setattr(self, "layer{0}".format(idx), ops)
in_channels = out_channels
out_channels *= 2
self.linear1 = torch.nn.Linear(in_channels * self.dense_features, 9,
bias=False)
self.linear2 = torch.nn.Linear(in_channels * self.dense_features, 3,
bias=False)
# Finally warp the moving image.
self.spatial_transform = SpatialTransformer(input_shape)
# Init weights
@torch.no_grad()
def weights_init(module):
if isinstance(module, nn.Conv3d):
logger.debug("Init weights of {0}...".format(module))
torch.nn.init.xavier_uniform_(module.weight)
torch.nn.init.constant_(module.bias, 0)
self.apply(weights_init)
def _conv(self, in_channels, out_channels, kernel_size, stride=1,
padding=1, bias=True, negative_slope=1e-2):
ops = nn.Sequential(collections.OrderedDict([
("conv", nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, bias=bias)),
("act", nn.LeakyReLU(negative_slope=negative_slope))
]))
return ops
def _double_conv(self, in_channels, out_channels, kernel_size, stride=2,
padding=1, bias=True, negative_slope=1e-2):
ops = nn.Sequential(collections.OrderedDict([
("conv1", nn.Conv3d(in_channels, out_channels, kernel_size,
stride=stride, padding=padding, bias=bias)),
("act1", nn.LeakyReLU(negative_slope=negative_slope)),
("conv2", nn.Conv3d(out_channels, out_channels, kernel_size,
stride=1, padding=padding, bias=bias)),
("act2", nn.LeakyReLU(negative_slope=negative_slope))
]))
return ops
def _downsample_shape(self, shape, nb_iterations=1, scale_factor=2):
shape = np.asarray(shape)
all_shapes = [shape.astype(int).tolist()]
for idx in range(nb_iterations):
shape = np.ceil(shape / scale_factor)
all_shapes.append(shape.astype(int).tolist())
return all_shapes
[docs] def affine_flow(self, affine, size, without_identity=False):
""" Generates a flow field given an affine matrix.
Parameters
----------
affine: Tensor (N, 4, 4)
an affine transform.
size tuple
the target output image size.
without_identity: bool, defaul False
set to true if the identity matrix has already been substrated to
the affine matrix.
Returns
-------
flow: Tensor
the generated affine flow field.
"""
if not isinstance(size, list):
size = list(size)
device = affine.device
if not without_identity:
mat_id = torch.eye(4)
mat_id = mat_id.view(1, 4, 4)
if device != -1:
mat_id = mat_id.to(device)
affine = affine - mat_id
n_batch = affine.size(0)
vectors = [torch.arange(0, val) for val in size]
for indx in range(len(size)):
vectors[indx] = vectors[indx] - (size[indx] - 1) // 2
grids = torch.meshgrid(vectors)
grid = torch.stack(grids)
grid = grid.type(torch.FloatTensor)
ones = torch.ones([1] + size, dtype=grid.dtype)
homography_grid = torch.cat([grid, ones], dim=0)
homography_grid = homography_grid.view(1, 4, 1, -1).permute(3, 0, 1, 2)
if device != -1:
homography_grid = homography_grid.to(device)
self.debug("grid", homography_grid)
affine = affine.view(1, n_batch, 4, 4)
self.debug("affine", affine)
flow = torch.matmul(affine, homography_grid)
flow = flow.permute(3, 1, 2, 0)
flow = flow.view([n_batch, 4] + size)
return flow[:, :3]
[docs] def forward(self, x):
""" Forward method.
y = Ax + b
the model learns W = A - I
Parameters
----------
x: Tensor
concatenated moving and fixed images (batch, 2 * channels, X, Y, Z)
"""
logger.debug("ADDNet...")
nb_channels = x.shape[1] // 2
device = x.get_device()
logger.debug(" nb_channels: {0}".format(nb_channels))
self.debug("input", x)
moving = x[:, :nb_channels]
self.debug("moving", moving)
for idx in range(1, 7):
logger.debug("Applying layer{0}...".format(idx))
self.debug("input", x)
layer = getattr(self, "layer{0}".format(idx))
logger.debug(" filter: {0}".format(layer))
x = layer(x)
self.debug("output", x)
logger.debug("Flatening...")
self.debug("input", x)
logger.debug(" dense features: {0}".format(self.dense_features))
x = x.view(-1, 512 * self.dense_features)
self.debug("output", x)
logger.debug("Getting W...")
vec_w = self.linear1(x)
self.debug("W", vec_w)
logger.debug("Getting b...")
vec_b = self.linear2(x)
self.debug("b", vec_b)
logger.debug("Getting A...")
# the flow is displacement(x) = place(x) - x = (Ax + b) - x
# the model learns W = A - I.
mat_id = torch.eye(3)
if device != -1:
mat_id = mat_id.to(device)
mat_id = mat_id.view(1, 3, 3)
mat_w = vec_w.view(-1, 3, 3) * self.flow_multiplier
vec_b = vec_b * self.flow_multiplier
mat_a = mat_w + mat_id
self.debug("A", mat_a)
logger.debug("Getting flow...")
self.debug("b", vec_b)
vec_b = vec_b.view(-1, 3, 1)
# affine = torch.cat((mat_a, vec_b), dim=2)
# theta = to_homography(affine)
# norm_theta = normalize_homography(
# theta, shape_src=self.input_shape, shape_dst=self.input_shape)
# theta = norm_theta[:, :3, :]
# self.debug("theta", theta)
# size = [mat_a.size(0), 1] + list(self.input_shape)
# logger.debug(" size: {0}".format(size))
# flow = func.affine_grid(theta, size, align_corners=False)
affine = to_homography(torch.cat((mat_w, vec_b), dim=2))
self.debug("affine", affine)
flow = self.affine_flow(
affine, self.input_shape, without_identity=True)
self.debug("flow", flow)
logger.debug("Applying warp...")
self.debug("moving", moving)
# warp = func.grid_sample(moving, flow, align_corners=False)
warp, _ = self.spatial_transform(moving, flow)
self.debug("warp", warp)
logger.debug("Done.")
return warp, {"flow": flow, "A": mat_a, "b": vec_b, "W": mat_w}
[docs] def debug(self, name, tensor):
""" Print debug message.
Parameters
----------
name: str
the tensor name in the displayed message.
tensor: Tensor
a pytorch tensor.
"""
logger.debug(" {3}: {0} - {1} - {2}".format(
tensor.shape, tensor.get_device(), tensor.dtype, name))
[docs]def to_homography(batch_affine):
""" Convert batch of affine matrices of size (N, 3, 4) to (N, 4, 4).
"""
affine = func.pad(batch_affine, [0, 0, 0, 1], "constant", value=0.)
affine[..., -1, -1] += 1.0
return affine
[docs]def normal_transform_pixel(shape):
""" Compute the normalization matrix from image size in pixels to [-1, 1].
"""
tr_mat = torch.tensor([[1.0, 0.0, 0.0, -1.0],
[0.0, 1.0, 0.0, -1.0],
[0.0, 0.0, 1.0, -1.0],
[0.0, 0.0, 0.0, 1.0]])
for idx in range(len(shape)):
tr_mat[idx, idx] = tr_mat[idx, idx] * 2.0 / (shape[idx] - 1.0)
tr_mat = tr_mat.unsqueeze(0)
return tr_mat
[docs]def normalize_homography(affine, shape_src, shape_dst):
""" Normalize a given homography in pixels to [-1, 1].
Reference: https://discuss.pytorch.org/t/
affine-transformation-matrix-paramters-conversion/19522/13
Parameters
----------
affine: torch.Tensor (N, 4, 4)
homography/ies from source to destiantion to be normalized.
shape_src: tuple (3,)
size of the source image.
shape_dst: tuple (3,)
size of the destination image.
Returns
-------
affine: torch.Tensor (N, 4, 4)
the normalized homography/ies.
"""
if not torch.is_tensor(affine):
raise TypeError("Input affine type is not a torch.Tensor.")
if not (len(affine.shape) == 3 or affine.shape[-2:] == (4, 4)):
raise ValueError("Input affine must be a Nx4x4.")
# Parameters
device = affine.device
dtype = affine.dtype
# Compute the transformation pixel/norm for src/dst
src_norm_trf_src_pix = normal_transform_pixel(shape_src).to(device, dtype)
src_pix_trf_src_norm = torch.inverse(src_norm_trf_src_pix)
dst_norm_trf_dst_pix = normal_transform_pixel(shape_dst).to(device, dtype)
# Compute chain transformations
dst_norm_trans_src_norm = (
dst_norm_trf_dst_pix @ (affine @ src_pix_trf_src_norm))
return dst_norm_trans_src_norm
[docs]@Regularizers.register
class ADDNetRegularizer(object):
""" ADDNet Combined Regularization.
In addition to the correlation coefficient as our similarity loss, the
orthogonality loss and the determinant loss are used as regularization
losses for the affine network.
k1 * DetLoss + k2 * OrthoLoss
DetLoss: determinant should be close to 1, ie. reflection is not allowed.
OrthoLoss: should be close to being orthogonal, ie. penalize the network
for producing overly non-rigid transform.
Let C=A'A, a positive semi-definite matrix should be close to I.
For this, we require C has eigen values close to 1 by minimizing
k1 + 1/k1 + k2 + 1/k2 + k3 + 1/k3. To prevent NaN, minimize
k1 + eps + (1+eps)^2 / (k1+eps) + ...
"""
[docs] def __init__(self, k1=0.1, k2=0.1, eps=1e-5):
self.k1 = k1
self.k2 = k2
self.eps = eps
self.det_loss = 0
self.ortho_loss = 0
def __call__(self, signal):
logger.debug("ADDNetRegularizer...")
mat_a = signal.layer_outputs["A"]
self.debug("A", mat_a)
device = mat_a.get_device()
det = mat_a.det()
self.debug("determinant", det)
self.det_loss = torch.norm(det - 1., 2)
logger.debug(" determinant loss: {0}".format(self.det_loss))
mat_eps = torch.eye(3)
if device != -1:
mat_eps = mat_eps.to(device)
mat_eps *= self.eps
mat_eps = mat_eps.view(1, 3, 3)
self.debug("eps", mat_eps)
mat_c = torch.bmm(mat_a.permute(0, 2, 1), mat_a) + mat_eps
self.debug("C", mat_c)
def elem_sym_polys_of_eigen_values(mat):
mat = [[mat[:, idx_i, idx_j]
for idx_j in range(3)] for idx_i in range(3)]
sigma1 = (mat[0][0] + mat[1][1] + mat[2][2])
sigma2 = (mat[0][0] * mat[1][1] +
mat[1][1] * mat[2][2] +
mat[2][2] * mat[0][0]) - (
mat[0][1] * mat[1][0] +
mat[1][2] * mat[2][1] +
mat[2][0] * mat[0][2])
sigma3 = (mat[0][0] * mat[1][1] * mat[2][2] +
mat[0][1] * mat[1][2] * mat[2][0] +
mat[0][2] * mat[1][0] * mat[2][1]) - (
mat[0][0] * mat[1][2] * mat[2][1] +
mat[0][1] * mat[1][0] * mat[2][2] +
mat[0][2] * mat[1][1] * mat[2][0])
return sigma1, sigma2, sigma3
s1, s2, s3 = elem_sym_polys_of_eigen_values(mat_c)
self.debug("s1", s1)
self.debug("s2", s2)
self.debug("s3", s3)
eps = self.eps
ortho_loss = s1 + (1 + eps) * (1 + eps) * s2 / s3 - 3 * 2 * (1 + eps)
self.debug("orthogonal", ortho_loss)
self.ortho_loss = self.k2 * torch.sum(ortho_loss)
logger.debug(" orthogonal loss: {0}".format(self.ortho_loss))
logger.debug("Done.")
return self.k1 * self.det_loss + self.k2 * self.ortho_loss
[docs] def debug(self, name, tensor):
""" Print debug message.
Parameters
----------
name: str
the tensor name in the displayed message.
tensor: Tensor
a pytorch tensor.
"""
logger.debug(" {3}: {0} - {1} - {2}".format(
tensor.shape, tensor.get_device(), tensor.dtype, name))
Follow us
© 2019, pynet developers .
Inspired by AZMIND template.
Inspired by AZMIND template.