Menu

Helper Module for Deep Learning.

Source code for pynet.models.rcnet

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019 - 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Recursive Cascaded Networks (RCNet) for Unsupervised Medical Image
Registration using an Affine and Dense Deformable Network (ADDNet) and a
Volume Tweening Network (VTN).
"""

# Imports
import logging
import collections
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.autograd import Variable
from pynet.observable import SignalObject
from .vtnet import ADDNetRegularizer
from .voxelmorphnet import FlowRegularizer
from .voxelmorphnet import SpatialTransformer
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks
from pynet.utils import Regularizers
from pynet.utils import get_tools


# Global parameters
Stem = namedtuple("Stem", ["network", "params"])
logger = logging.getLogger("pynet")


@Networks.register
@DeepLearningDecorator(family="register")
class RCNet(nn.Module):
    """ RCNet.

    The recursive cascaded networks is a general architecture that enables
    learning deep cascades and can be used for deformable image
    registration. The cascade architecture is simple in design and can be
    built on any base network.

    The moving image is warped successively by each cascade and finally
    aligned to the fixed image; this procedure is recursive in a way that
    every cascade learns to perform a progressive deformation for the
    current warped image. The entire system is end-to-end and jointly
    trained in an unsupervised manner.

    Shared-weight techniques are developed in addition to the recursive
    architecture. Shared-weight cascading in training is not used since it
    consumes extra GPU memory.

    We use the Affine and Dense Deformable Network (ADDNet) to estimate the
    affine transform in combination with a deformation field network
    estimator.

    This network achieves state-of-the-art performance on both liver CT and
    brain MRI datasets for 3D medical image registration.

    Reference:

    * https://arxiv.org/pdf/1907.12353
    * https://arxiv.org/pdf/1902.05020

    Code:

    * https://github.com/microsoft/Recursive-Cascaded-Networks.
    """
    # Default per-stem loss weights, used for every key a stem does not
    # set explicitly.
    default_params = {
        "weight": 1.,
        "raw_weight": 1.,
        "reg_weight": 1.}

    def __init__(self, input_shape, in_channels, base_network, n_cascades=1,
                 rep=1):
        """ Init class.

        Parameters
        ----------
        input_shape: uplet
            the tensor data shape (X, Y, Z).
        in_channels: int
            number of channels in the input tensor.
        base_network: str
            the name of the Network used to estimate the non-linear
            deformation.
        n_cascades: int, default 1
            the number of cascades, each with its own base network.
        rep: int, default 1
            the number of times of shared-weight cascading, ie. how many
            consecutive stems reuse the same cascade network.

        Raises
        ------
        ValueError
            if `base_network` is not a registered network or if
            `n_cascades` is lower than 1.
        """
        # Inheritance
        logger.debug("RCNet configuration...")
        nn.Module.__init__(self)

        # Class parameters
        available_networks = get_tools()["networks"]
        if base_network not in available_networks:
            raise ValueError(
                "Unknown base network '{0}', available networks are "
                "{1}.".format(base_network, available_networks.keys()))
        if n_cascades < 1:
            raise ValueError(
                "The number of cascades must be at least 1, got "
                "{0}.".format(n_cascades))
        self.base_network = available_networks[base_network]
        logger.debug("  base network: {0}".format(self.base_network))

        # First stem: affine registration, no raw/reg loss of its own.
        self.stems = [Stem(
            network=available_networks["ADDNet"](
                input_shape=input_shape, in_channels=in_channels,
                flow_multiplier=1.),
            params={"raw_weight": 0, "reg_weight": 0})]
        # One distinct network per cascade. Each one is repeated `rep`
        # times (shared-weight cascading), and its repeats also share the
        # same params dict. List multiplication over the whole cascade
        # range would wrongly make ALL cascades share a single network.
        for _ in range(n_cascades):
            cascade = Stem(
                network=self.base_network(
                    input_shape=input_shape, in_channels=in_channels,
                    flow_multiplier=(1. / n_cascades)),
                params={"raw_weight": 0})
            self.stems.extend([cascade] * rep)
        self.stems[-1].params["raw_weight"] = 1
        for stem in self.stems:
            for key, val in self.default_params.items():
                stem.params.setdefault(key, val)
        logger.debug("  stems: {0}".format(self.stems))

        # Register every distinct stem network as a sub-module so that
        # state_dict(), to(), train() and eval() reach them; a plain Python
        # list is invisible to nn.Module.
        self.stem_networks = nn.ModuleList(self._unique_networks())

        # Finally warp the moving image: avoid accumulation of interpolation
        # errors, ie. reinterpolate after each cascade.
        self.spatial_transform = SpatialTransformer(input_shape)

    def _unique_networks(self):
        """ List the distinct stem networks, in cascade order. """
        seen = set()
        networks = []
        for stem in self.stems:
            if id(stem.network) not in seen:
                seen.add(id(stem.network))
                networks.append(stem.network)
        return networks

    def parameters(self):
        """ Get the trainable variables.

        Parameters shared between stems are returned only once, in a
        deterministic (cascade) order, so that optimizer states can be
        saved and restored reliably.

        Returns
        -------
        params: list of Parameter
            the parameters of all the stem networks.
        """
        seen = set()
        params = []
        for network in self._unique_networks():
            for param in network.parameters():
                if id(param) not in seen:
                    seen.add(id(param))
                    params.append(param)
        return params

    @property
    def trainable_parameters(self):
        """ Get the number of trainable parameters.

        Shared parameters are counted once, and parameters with
        `requires_grad=False` are ignored.
        """
        return sum(
            param.numel() for param in self.parameters()
            if param.requires_grad)

    def forward(self, x):
        """ Forward method.

        Parameters
        ----------
        x: Tensor
            concatenated moving and fixed images
            (batch, 2 * channels, X, Y, Z)

        Returns
        -------
        warp: Tensor
            the moving image warped with the final aggregated flow.
        outputs: dict
            "flow" (final aggregated displacement), "stem_results" (one
            dict per stem) and "jacobian_det" (always 0, not computed
            here).
        """
        logger.debug("RCNET...")
        device = x.device
        # Lazily move the stem networks to the input device; networks
        # without parameters have nothing to move.
        for network in self._unique_networks():
            first_param = next(network.parameters(), None)
            if first_param is not None and first_param.device != device:
                network.to(device)

        nb_channels = x.shape[1] // 2
        moving = x[:, :nb_channels]
        fixed = x[:, nb_channels:]

        stem_results = []
        affine_stem = self.stems[0]
        warp, stem_result = affine_stem.network(
            torch.cat((moving, fixed), dim=1))
        stem_result["warped"] = warp
        stem_result["agg_flow"] = stem_result["flow"]
        stem_result["stem_params"] = affine_stem.params
        stem_results.append(stem_result)

        for stem in self.stems[1:]:
            previous = stem_results[-1]
            _, stem_result = stem.network(
                torch.cat((previous["warped"], fixed), dim=1))
            stem_result["stem_params"] = stem.params
            # Compose displacement fields instead of adding them. The new
            # flow samples the previously warped image at p + flow(p), and
            # that image was itself sampled from `moving` at
            # q + agg_flow(q). Hence:
            #     agg(p) = flow(p) + agg_flow(p + flow(p)).
            # NOTE(review): assumes the stem flows follow the same voxel
            # convention as `spatial_transform`, as the warping below does.
            warped_agg_flow, _ = self.spatial_transform(
                previous["agg_flow"], stem_result["flow"])
            stem_result["agg_flow"] = warped_agg_flow + stem_result["flow"]
            # Re-warp the original moving image with the aggregated flow to
            # avoid accumulating interpolation errors.
            stem_result["warped"], _ = self.spatial_transform(
                moving, stem_result["agg_flow"])
            stem_results.append(stem_result)

        flow = stem_results[-1]["agg_flow"]
        warp = stem_results[-1]["warped"]
        # The Jacobian determinant statistic is not computed during the
        # forward pass; call `self.jacobian_det(flow)` explicitly if needed.
        jacobian_det = 0
        return warp, {"flow": flow, "stem_results": stem_results,
                      "jacobian_det": jacobian_det}

    def jacobian_det(self, flow):
        """ Compute the spread of the Jacobian determinant of a displacement
        field.

        The Jacobian of the mapping p -> p + flow(p) is approximated with
        forward finite differences.

        Parameters
        ----------
        flow: Tensor
            displacement field (batch, 3, X, Y, Z), where channel i is the
            displacement along spatial axis i.

        Returns
        -------
        std: Tensor
            (batch, ) standard deviation (population) of the Jacobian
            determinant over the (X - 1, Y - 1, Z - 1) grid.
        """
        # Identity built on the flow device/dtype to avoid device mismatch.
        identity = torch.eye(3, dtype=flow.dtype, device=flow.device)
        origin = flow[:, :, :-1, :-1, :-1]
        diffs = (
            flow[:, :, 1:, :-1, :-1] - origin,
            flow[:, :, :-1, 1:, :-1] - origin,
            flow[:, :, :-1, :-1, 1:] - origin)
        # Row i: derivative along axis i of every displacement component,
        # plus the identity contribution. Channels are moved last so that
        # the stacked tensor ends with the (3, 3) matrices.
        rows = [
            (diff + identity[idx].view(1, 3, 1, 1, 1)).permute(0, 2, 3, 4, 1)
            for idx, diff in enumerate(diffs)]
        jac = torch.stack(rows, dim=-2)
        det = torch.det(jac)
        return det.reshape(det.shape[0], -1).std(dim=1, unbiased=False)
@Regularizers.register
class RCNetRegularizer(object):
    """ RCNet Regularization.

    ADDNetRegularizer + FlowRegularizer: the affine stem (whose outputs
    contain "W") is regularized with the ADDNet regularizer, and every other
    stem with the flow regularizer scaled by its "reg_weight". Each stem
    term is then weighted by the stem "weight" parameter and summed.
    """
    def __init__(self, det_factor=0.1, ortho_factor=0.1, reg_factor=1.0):
        """ Init class.

        Parameters
        ----------
        det_factor: float, default 0.1
            weight forwarded as `k1` to the ADDNetRegularizer.
        ortho_factor: float, default 0.1
            weight forwarded as `k2` to the ADDNetRegularizer.
        reg_factor: float, default 1.0
            weight forwarded as `k1` to the FlowRegularizer.
        """
        self.addnet_reg = ADDNetRegularizer(k1=det_factor, k2=ortho_factor)
        self.flow_reg = FlowRegularizer(k1=reg_factor)

    def __call__(self, signal):
        """ Compute the RCNet regularization term.

        Parameters
        ----------
        signal: SignalObject
            object whose `layer_outputs` holds a "stem_results" list, as
            returned by `RCNet.forward`; each stem result must carry a
            "stem_params" dict.

        Returns
        -------
        loss: Tensor or number
            sum over stems of the stem regularization multiplied by the
            stem "weight" parameter. Each stem result also receives its own
            unweighted term under the "loss" key.
        """
        logger.debug("Compute RCNet regularization...")
        stem_results = signal.layer_outputs["stem_results"]
        for stem_result in stem_results:
            params = stem_result["stem_params"]
            # The regularizers read their inputs from `layer_outputs`, so
            # expose only the outputs of the current stem.
            sub_signal = SignalObject()
            sub_signal.layer_outputs = stem_result
            if "W" in stem_result:
                stem_result["loss"] = self.addnet_reg(sub_signal)
            elif params["reg_weight"] > 0:
                stem_result["loss"] = (
                    self.flow_reg(sub_signal) * params["reg_weight"])
            else:
                # No regularization requested for this stem.
                stem_result["loss"] = 0
        loss = sum(
            stem_result["loss"] * stem_result["stem_params"]["weight"]
            for stem_result in stem_results)
        logger.debug("Done.")
        return loss

Follow us

© 2019, pynet developers .
Inspired by AZMIND template.