Helper Module for Deep Learning.
Recursive Cascaded Networks (RCNet) for Unsupervised Medical Image Registration using the Affine and Dense Deformable Network (ADDNet) and the Volume Tweening Network (VTN).
class pynet.models.rcnet.RCNet(input_shape, in_channels, base_network, n_cascades=1, rep=1)
RCNet.
Recursive cascaded networks are a general architecture that enables learning deep cascades for deformable image registration. The cascade architecture is simple in design and can be built on any base network. The moving image is warped successively by each cascade until it is aligned to the fixed image; the procedure is recursive in that every cascade learns a progressive deformation of the current warped image. The whole system is trained end-to-end and jointly, in an unsupervised manner. Shared-weight techniques are developed in addition to the recursive architecture; shared-weight cascading is not used during training since it consumes extra GPU memory.
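The recursive procedure can be illustrated with a toy, one-dimensional NumPy sketch. It is not pynet code: each cascade is a hypothetical callable `(fixed, warped) -> flow`, and the composition rule agg_k(x) = flow_k(x) + agg_{k-1}(x + flow_k(x)), with the original moving image re-warped by the aggregated field, is an assumption based on our reading of the reference repository listed under Code.

```python
import numpy as np
from typing import Callable, List, Tuple

# Hypothetical cascade type: (fixed, current warped image) -> displacement field.
Cascade = Callable[[np.ndarray, np.ndarray], np.ndarray]


def warp_1d(signal: np.ndarray, flow: np.ndarray) -> np.ndarray:
    """Resample `signal` at x + flow(x) by linear interpolation (edges clamped)."""
    grid = np.arange(signal.shape[0], dtype=float)
    return np.interp(grid + flow, grid, signal)


def recursive_cascade(fixed: np.ndarray, moving: np.ndarray,
                      cascades: List[Cascade]) -> Tuple[np.ndarray, np.ndarray]:
    """Apply cascades one after another, composing their displacement fields."""
    moving = moving.astype(float)
    agg_flow = np.zeros_like(moving)
    warped = moving
    for cascade in cascades:
        # Each cascade only sees the current warped image and the fixed image.
        flow = cascade(fixed, warped)
        # Compose fields: agg_k(x) = flow_k(x) + agg_{k-1}(x + flow_k(x)).
        agg_flow = flow + warp_1d(agg_flow, flow)
        # Re-warp the original moving image with the aggregated field.
        warped = warp_1d(moving, agg_flow)
    return warped, agg_flow
```

Re-warping the original image with the aggregated field, rather than re-warping the previous output, avoids compounding interpolation blur across cascades.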
We use the Affine and Dense Deformable Network (ADDNet) to estimate the affine transform, in combination with a network that estimates the deformation field.
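One way an affine estimate can be expressed in the same displacement-field form as the deformable cascades is to convert a matrix A and offset b into u(x) = A x + b - x on a voxel grid, as in the NumPy sketch that follows. The helper name `affine_to_flow` is hypothetical, and using raw voxel indices as coordinates is an assumption; the actual network may use centred or normalised coordinates.

```python
import numpy as np


def affine_to_flow(matrix: np.ndarray, offset: np.ndarray, shape) -> np.ndarray:
    """Displacement field u(x) = A x + b - x, returned with shape (ndim, *shape).

    Assumption: coordinates are raw voxel indices (no centring or scaling).
    """
    ndim = len(shape)
    # Coordinate grid of shape (ndim, *shape), e.g. (3, X, Y, Z) in 3D.
    grid = np.stack(np.meshgrid(*[np.arange(n, dtype=float) for n in shape],
                                indexing="ij"))
    # Apply the matrix to every voxel coordinate.
    moved = np.einsum("ij,j...->i...", matrix, grid)
    # Broadcast the offset over all voxels and subtract identity positions.
    return moved + offset.reshape((ndim,) + (1,) * ndim) - grid
```

A pure translation gives a constant field, and the identity matrix with zero offset gives a zero field.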
The original authors report state-of-the-art performance on both liver CT and brain MRI datasets for 3D medical image registration.
References:
* https://arxiv.org/pdf/1907.12353
* https://arxiv.org/pdf/1902.05020
Code:
* https://github.com/microsoft/Recursive-Cascaded-Networks
__init__(input_shape, in_channels, base_network, n_cascades=1, rep=1)
Init class.
Parameters
input_shape: tuple of int
the tensor data shape (X, Y, Z).
in_channels: int
number of channels in the input tensor.
base_network: str
the name of the network used to estimate the non-linear deformation.
n_cascades: int, default 1
the number of cascades.
rep: int, default 1
the number of shared-weight cascading repetitions.
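To make the interaction of `n_cascades` and `rep` concrete, here is a hypothetical helper listing the order in which sub-networks would be applied. It assumes, from our reading of the reference implementation rather than pynet's own source, that an affine stage runs first and that each deformable cascade is reused `rep` times in a row with shared weights.

```python
from typing import List


def cascade_schedule(n_cascades: int = 1, rep: int = 1) -> List[str]:
    """Return the order in which sub-networks are applied (hypothetical helper)."""
    if n_cascades < 1 or rep < 1:
        raise ValueError("n_cascades and rep must both be >= 1")
    # Assumption: an affine stage runs first, as in the reference implementation.
    schedule = ["affine"]
    for i in range(n_cascades):
        # Shared-weight cascading: the same dense cascade is reused `rep` times.
        schedule.extend([f"dense_{i}"] * rep)
    return schedule
```

With the defaults (`n_cascades=1`, `rep=1`) this yields one affine stage followed by one deformable stage; in general the schedule has 1 + n_cascades * rep stages.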
default_params = {'raw_weight': 1.0, 'reg_weight': 1.0, 'weight': 1.0}
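The keys of `default_params` suggest per-stage loss weights: a similarity term on the warped image (`raw_weight`), a smoothness regularizer on the flow (`reg_weight`) and an overall stage weight (`weight`). This page does not say how they are combined, so the following is a guess: missing per-stage entries fall back to the defaults, and the total loss is a weighted sum over stages. All helper names are hypothetical.

```python
from typing import Dict, List

# Same values as default_params above.
DEFAULT_PARAMS = {"raw_weight": 1.0, "reg_weight": 1.0, "weight": 1.0}


def stage_params(overrides: Dict[str, float]) -> Dict[str, float]:
    """Fill missing per-stage entries from the defaults (hypothetical helper)."""
    return {**DEFAULT_PARAMS, **overrides}


def total_loss(similarity: List[float], regularity: List[float],
               overrides: List[Dict[str, float]]) -> float:
    """Assumed combination: sum of weight * (raw_weight * sim + reg_weight * reg)."""
    loss = 0.0
    for sim, reg, over in zip(similarity, regularity, overrides):
        p = stage_params(over)
        loss += p["weight"] * (p["raw_weight"] * sim + p["reg_weight"] * reg)
    return loss
```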
forward(x)
Forward method.
Parameters
x: Tensor
concatenated moving and fixed images (batch, 2 * channels, X, Y, Z)
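A minimal NumPy sketch of preparing `x`: moving and fixed volumes of shape (batch, channels, X, Y, Z) are concatenated along the channel axis, moving first, matching the order stated above. With PyTorch tensors, `torch.cat([moving, fixed], dim=1)` does the same. The sizes and random data are placeholders.

```python
import numpy as np

batch, channels, shape = 2, 1, (8, 8, 8)
rng = np.random.default_rng(0)
# Placeholder volumes standing in for real moving and fixed images.
moving = rng.random((batch, channels, *shape), dtype=np.float32)
fixed = rng.random((batch, channels, *shape), dtype=np.float32)

# Stack along the channel axis: moving channels first, then fixed channels.
x = np.concatenate([moving, fixed], axis=1)
print(x.shape)  # (2, 2, 8, 8, 8)
```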
property trainable_parameters
Get the number of trainable parameters.
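The property's implementation is not shown on this page; a common PyTorch idiom for such a count sums `numel()` over the parameters that have `requires_grad` set. The sketch applies that idiom to a small stand-in model, not to `RCNet` itself.

```python
import torch.nn as nn


def count_trainable(model: nn.Module) -> int:
    """Total number of elements in parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = nn.Sequential(nn.Linear(3, 2), nn.Linear(2, 1))
# Frozen parameters are excluded from the count.
model[1].weight.requires_grad_(False)
print(count_trainable(model))  # 9
```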