Menu

Helper Module for Deep Learning.

Source code for pynet.models.sononet

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2020
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Sononet is a CNN architecture with two components: a feature extractor module
and an adaptation module.
"""

# Imports
import logging
import collections
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as func
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks
import numpy as np


# Global parameters
# Logger registered under the "pynet" name; the classes in this module emit
# their debug traces through it.
logger = logging.getLogger("pynet")


# NOTE(review): ("classifier") is a plain str, not a 1-tuple - confirm which
# form DeepLearningDecorator expects for 'family'.
@Networks.register
@DeepLearningDecorator(family=("classifier"))
class SonoNet(nn.Module):
    """ SonoNet.

    Feature extraction: the first 17 layers (counting max-pooling) of the
    VGG network is used to extract discriminant features (3 layers for the
    first 3 and 2 layers for the last 2 feature scales).
    In this implementation the successive feature scales use
    start_filts * (1, 2, 3, 4, 4) filters.

    Attention maps (adaptation module): the number of channels are first
    reduced to the number of target classes C. Subsequently, the spatial
    information is flattened via channel-wise global average pooling.
    Finally, a soft-max operation is applied to the resulting vector and the
    entry with maximum activation is selected as the prediction.
    As the network is constrained to classify based on the reduced vector,
    the network is forced to extract the most salient features for each
    class.

    Reference: Attention-Gated Networks for Improving Ultrasound Scan Plane
    Detection https://arxiv.org/pdf/1804.05338.pdf
    Code: https://github.com/ozan-oktay/Attention-Gated-Networks
    """
    def __init__(self, n_classes, in_channels=1, n_convs=[3, 3, 3, 2, 2],
                 start_filts=64, batchnorm=True,
                 nonlocal_mode="concatenation", aggregation_mode="concat"):
        """ Init class.

        Parameters
        ----------
        n_classes: int
            the number of target classes (size of each prediction vector).
        in_channels: int, default 1
            number of channels in the input tensor.
        n_convs: list of int, default [3, 3, 3, 2, 2]
            the number of convolutions in each of the 5 feature scales.
        start_filts: int, default 64
            number of convolutional filters for the first conv.
        batchnorm: bool, default True
            normalize the inputs of the activation function.
        nonlocal_mode: str, default 'concatenation'
            forwarded as the 'mode' of the two grid attention blocks.
        aggregation_mode: str, default 'concat'
            how the two attended feature vectors and the pooled deepest
            features are turned into predictions: 'concat', 'mean',
            'deep_sup' or 'ft'.

        Raises
        ------
        NotImplementedError
            if 'aggregation_mode' is not one of the supported values.
        """
        # Inheritance
        nn.Module.__init__(self)

        # Parameters
        self.n_classes = n_classes
        self.in_channels = in_channels
        # Copy so that the shared default list can never be mutated through
        # an instance attribute.
        self.n_convs = list(n_convs)
        self.start_filts = start_filts
        self.batchnorm = batchnorm
        self.nonlocal_mode = nonlocal_mode
        self.aggregation_mode = aggregation_mode

        # Feature Extraction
        # NOTE(review): widths grow linearly (x1, x2, x3, x4), not by
        # doubling as in the referenced paper; kept as is because changing
        # it would alter parameter shapes of existing checkpoints.
        filters = [
            start_filts * cnt for cnt in range(1, len(self.n_convs))]
        self.conv1 = Conv2(
            self.in_channels, filters[0], self.batchnorm, n=n_convs[0])
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = Conv2(
            filters[0], filters[1], self.batchnorm, n=n_convs[1])
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = Conv2(
            filters[1], filters[2], self.batchnorm, n=n_convs[2])
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = Conv2(
            filters[2], filters[3], self.batchnorm, n=n_convs[3])
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        self.conv5 = Conv2(
            filters[3], filters[3], self.batchnorm, n=n_convs[4])

        # Attention Maps: conv3 and conv4 features are gated by conv5.
        self.compatibility_score1 = GridAttentionBlock2D(
            in_channels=filters[2], gating_channels=filters[3],
            inter_channels=filters[3], sub_sample_factor=(1, 1),
            mode=nonlocal_mode, use_W=False, use_phi=True, use_theta=True,
            use_psi=True, nonlinearity1="relu")
        self.compatibility_score2 = GridAttentionBlock2D(
            in_channels=filters[3], gating_channels=filters[3],
            inter_channels=filters[3], sub_sample_factor=(1, 1),
            mode=nonlocal_mode, use_W=False, use_phi=True, use_theta=True,
            use_psi=True, nonlinearity1="relu")

        # Aggregation Strategies
        self.attention_filter_sizes = [filters[2], filters[3]]
        n_concat_features = filters[2] + filters[3] + filters[3]
        if aggregation_mode == "concat":
            self.classifier = nn.Linear(n_concat_features, n_classes)
            self.aggregate = self.aggregation_concat
        else:
            # One classifier per feature vector (g1, g2, pooled).
            self.classifier1 = nn.Linear(filters[2], n_classes)
            self.classifier2 = nn.Linear(filters[3], n_classes)
            self.classifier3 = nn.Linear(filters[3], n_classes)
            self.classifiers = [self.classifier1, self.classifier2,
                                self.classifier3]
            if aggregation_mode == "mean":
                self.aggregate = self.aggregation_sep
            elif aggregation_mode == "deep_sup":
                self.classifier = nn.Linear(n_concat_features, n_classes)
                self.aggregate = self.aggregation_ds
            elif aggregation_mode == "ft":
                self.classifier = nn.Linear(n_classes * 3, n_classes)
                self.aggregate = self.aggregation_ft
            else:
                raise NotImplementedError(
                    "Unknown aggregation mode '{0}'.".format(
                        aggregation_mode))

        # Initialise weights
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type="kaiming")

    def aggregation_sep(self, *attended_maps):
        """ Apply one dedicated classifier per feature vector and return
        the list of predictions.
        """
        return [clf(att) for clf, att in zip(self.classifiers, attended_maps)]

    def aggregation_ft(self, *attended_maps):
        """ Concatenate the separate predictions and classify them again.
        """
        preds = self.aggregation_sep(*attended_maps)
        return self.classifier(torch.cat(preds, dim=1))

    def aggregation_ds(self, *attended_maps):
        """ Deep supervision: return the prediction on the concatenated
        features followed by the separate predictions.
        """
        preds_sep = self.aggregation_sep(*attended_maps)
        pred = self.aggregation_concat(*attended_maps)
        return [pred] + preds_sep

    def aggregation_concat(self, *attended_maps):
        """ Classify the concatenation of all feature vectors.
        """
        return self.classifier(torch.cat(attended_maps, dim=1))

    def forward(self, inputs):
        """ Forward pass.

        Parameters
        ----------
        inputs: Tensor
            the input batch; the first axis is used as the batch size.

        Returns
        -------
        out: Tensor or list of Tensor
            the result of the configured aggregation. In 'mean' mode the
            separate predictions are reshaped to (-1, n_classes, 1) and
            concatenated on the last axis; in 'deep_sup' mode a list of
            predictions is returned.
        """
        logger.debug("SONO Net...")
        logger.debug("Feature Extraction:")
        self.debug("input", inputs)
        conv1 = self.conv1(inputs)
        self.debug("conv1", conv1)
        maxpool1 = self.maxpool1(conv1)
        self.debug("maxpool1", maxpool1)
        conv2 = self.conv2(maxpool1)
        self.debug("conv2", conv2)
        maxpool2 = self.maxpool2(conv2)
        self.debug("maxpool2", maxpool2)
        conv3 = self.conv3(maxpool2)
        self.debug("conv3", conv3)
        maxpool3 = self.maxpool3(conv3)
        self.debug("maxpool3", maxpool3)
        conv4 = self.conv4(maxpool3)
        self.debug("conv4", conv4)
        maxpool4 = self.maxpool4(conv4)
        self.debug("maxpool4", maxpool4)
        conv5 = self.conv5(maxpool4)
        self.debug("conv5", conv5)
        batch_size = inputs.shape[0]
        pooled = func.adaptive_avg_pool2d(conv5, (1, 1)).view(batch_size, -1)
        self.debug("pooled", pooled)

        logger.debug("Attention Mechanism:")
        g_conv1, att1 = self.compatibility_score1(conv3, conv5)
        self.debug("g_conv1", g_conv1)
        self.debug("att1", att1)
        g_conv2, att2 = self.compatibility_score2(conv4, conv5)
        self.debug("g_conv2", g_conv2)
        self.debug("att2", att2)

        logger.debug("Flatten to get single feature vector:")
        fsizes = self.attention_filter_sizes
        g1 = torch.sum(g_conv1.view(batch_size, fsizes[0], -1), dim=-1)
        self.debug("g1", g1)
        g2 = torch.sum(g_conv2.view(batch_size, fsizes[1], -1), dim=-1)
        self.debug("g2", g2)

        logger.debug("Aggregate:")
        out = self.aggregate(g1, g2, pooled)
        if self.aggregation_mode == "mean":
            out = [item.view(-1, self.n_classes, 1) for item in out]
            out = torch.cat(out, dim=2)
        # 'deep_sup' returns a list of tensors: trace each item rather than
        # asking a list for tensor attributes.
        if isinstance(out, (list, tuple)):
            for cnt, item in enumerate(out):
                self.debug("out{0}".format(cnt), item)
        else:
            self.debug("out", out)

        return out

    @staticmethod
    def apply_softmax(pred):
        """ Apply a soft-max along axis 1 of the predictions.
        """
        return func.softmax(pred, dim=1)

    @staticmethod
    def aggregate_output(output, aggregation="mean",
                         aggregation_weight=(1, 1, 1), idx=0):
        """ Given a list of predictions from the model, make a decision
        based on aggregation rules.

        Parameters
        ----------
        output: Tensor
            either a 2-d tensor of scores indexed as [sample, class], or a
            3-d tensor indexed as [sample, class, output] holding several
            predictions stacked on the last axis.
        aggregation: str, default 'mean'
            for a 3-d input: 'max', 'mean' or 'weighted_mean'; any other
            value selects the single prediction number 'idx'.
        aggregation_weight: sequence of float, default (1, 1, 1)
            one weight per stacked prediction ('weighted_mean' only).
        idx: int, default 0
            the prediction to keep when no aggregation rule is selected.

        Returns
        -------
        pred: Tensor
            the index of the winning class for each sample.
        """
        if output.ndim != 3:
            probs = SonoNet.apply_softmax(output)
            _, pred = probs.data.max(dim=1)
            return pred

        # Soft-max each stacked prediction separately: (n_outputs, N, C).
        # A dedicated loop variable is used so that 'idx' keeps the value
        # given by the caller.
        probs = torch.cat(
            [SonoNet.apply_softmax(output[:, :, cnt]).unsqueeze(dim=0)
             for cnt in range(output.shape[2])], dim=0)
        if aggregation == "max":
            _, pred = probs.data.max(dim=0)[0].max(dim=1)
        elif aggregation == "mean":
            _, pred = probs.mean(dim=0).max(dim=1)
        elif aggregation == "weighted_mean":
            weight_t = torch.as_tensor(
                aggregation_weight, dtype=torch.float32,
                device=output.device).view(-1, 1, 1)
            wprobs = weight_t.expand_as(probs) * probs
            _, pred = wprobs.data.mean(dim=0).max(dim=1)
        else:
            # Outputs are stacked on the first axis of 'probs'.
            _, pred = probs[idx].data.max(dim=1)
        return pred

    def debug(self, name, tensor):
        """ Trace the shape, device index and dtype of a tensor at the
        debug level.
        """
        # Lazy %-formatting: the message is only built if it is emitted.
        logger.debug(" %s: %s - %s - %s", name, tensor.shape,
                     tensor.get_device(), tensor.dtype)
class _GridAttentionBlockND(nn.Module):
    """ Grid attention gate for 2-d or 3-d feature maps.

    A compatibility score is computed from the input features 'x' and a
    gating signal 'g' as psi(nl1(theta(x) + phi(g))), normalised according
    to 'mode', resampled to the spatial size of 'x' and used to re-weight
    'x'.
    """
    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 dimension=3, mode="concatenation",
                 sub_sample_factor=(1, 1, 1), bn_layer=True, use_W=True,
                 use_phi=True, use_theta=True, use_psi=True,
                 nonlinearity1="relu"):
        """ Init class.

        Parameters
        ----------
        in_channels: int
            number of channels of the input features 'x'.
        gating_channels: int
            number of channels of the gating signal 'g'.
        inter_channels: int, default None
            number of channels of the intermediate representation; if None,
            'in_channels // 2' (at least 1) is used.
        dimension: int, default 3
            2 or 3, selects 2-d or 3-d convolutions.
        mode: str, default 'concatenation'
            the normalisation of the compatibility scores: 'concatenation',
            'concatenation_softmax', 'concatenation_sigmoid',
            'concatenation_mean', 'concatenation_range_normalise' or
            'concatenation_mean_flow'. Plain 'concatenation' uses a sigmoid.
        sub_sample_factor: tuple or int, default (1, 1, 1)
            kernel size and stride of the theta and phi convolutions.
        bn_layer: bool, default True
            add a batch normalisation after the output convolution W.
        use_W, use_phi, use_theta, use_psi: bool, default True
            if False the corresponding operation is the identity.
        nonlinearity1: str, default 'relu'
            'relu' applies an in-place ReLU to theta(x) + phi(g); any other
            value keeps the identity.
        """
        super(_GridAttentionBlockND, self).__init__()
        assert dimension in [2, 3]
        assert mode in ["concatenation", "concatenation_softmax",
                        "concatenation_sigmoid", "concatenation_mean",
                        "concatenation_range_normalise",
                        "concatenation_mean_flow"]

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_factor = (
            sub_sample_factor if isinstance(sub_sample_factor, tuple)
            else tuple([sub_sample_factor]) * dimension)
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = "trilinear"
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = "bilinear"
        else:
            # 'NotImplemented' is a constant, not an exception: raising it
            # would fail with a TypeError.
            raise NotImplementedError(
                "Only 2-d and 3-d attention blocks are supported.")

        # initialise id functions
        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.W = lambda x: x
        self.theta = lambda x: x
        self.psi = lambda x: x
        self.phi = lambda x: x
        self.nl1 = lambda x: x

        if use_W:
            if bn_layer:
                self.W = nn.Sequential(
                    conv_nd(in_channels=self.in_channels,
                            out_channels=self.in_channels, kernel_size=1,
                            stride=1, padding=0),
                    bn(self.in_channels),
                )
            else:
                self.W = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.in_channels,
                                 kernel_size=1, stride=1, padding=0)
        if use_theta:
            self.theta = conv_nd(in_channels=self.in_channels,
                                 out_channels=self.inter_channels,
                                 kernel_size=self.sub_sample_kernel_size,
                                 stride=self.sub_sample_factor, padding=0,
                                 bias=False)
        if use_phi:
            self.phi = conv_nd(in_channels=self.gating_channels,
                               out_channels=self.inter_channels,
                               kernel_size=self.sub_sample_kernel_size,
                               stride=self.sub_sample_factor, padding=0,
                               bias=False)
        if use_psi:
            self.psi = conv_nd(in_channels=self.inter_channels,
                               out_channels=1, kernel_size=1, stride=1,
                               padding=0, bias=True)
        if nonlinearity1:
            if nonlinearity1 == "relu":
                self.nl1 = lambda x: func.relu(x, inplace=True)

        if "concatenation" in mode:
            self.operation_function = self._concatenation
        else:
            raise NotImplementedError("Unknown operation function.")

        # Initialise weights
        for m in self.children():
            init_weights(m, init_type="kaiming")
        # A positive psi bias starts these gates close to 'pass through'.
        # No specific bias is set for the other modes.
        if use_psi and self.mode == "concatenation_sigmoid":
            nn.init.constant_(self.psi.bias.data, 3.0)
        if use_psi and self.mode == "concatenation_softmax":
            nn.init.constant_(self.psi.bias.data, 10.0)

    def forward(self, x, g):
        """ Forward pass.

        Parameters
        ----------
        x: Tensor
            the input features, (b, c, h, w) or (b, c, t, h, w).
        g: Tensor
            the gating signal: same batch size as 'x', 'gating_channels'
            channels, and spatial axes (it is convolved by phi and
            resampled to the spatial size of theta(x)).

        Returns
        -------
        W_y: Tensor
            the attended features.
        sigm_psi_f: Tensor
            the attention map, resampled to the spatial size of 'x'.
        """
        output = self.operation_function(x, g)
        return output

    def _concatenation(self, x, g):
        """ Additive attention: see 'forward' for the parameters and
        returned values.
        """
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # compute compatibility score
        # theta => (b, c, t, h, w) -> (b, i_c, t/s1, h/s2, w/s3)
        # phi   => (b, g_c, ...)   -> (b, i_c, ...), resampled to theta(x)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()
        phi_g = func.interpolate(self.phi(g), size=theta_x_size[2:],
                                 mode=self.upsample_mode, align_corners=True)
        # nl(theta.x + phi.g + bias) -> f = (b, i_c, t/s1, h/s2, w/s3)
        f = self.nl1(theta_x + phi_g)
        # psi^T . f -> (b, 1, t/s1, h/s2, w/s3)
        psi_f = self.psi(f)

        # normalisation -- scale compatibility score
        sigm_psi_f = self._normalise_scores(
            psi_f, batch_size, theta_x_size[2:])

        # sigm_psi_f is attention map! upsample the attentions and multiply
        sigm_psi_f = func.interpolate(sigm_psi_f, size=input_size[2:],
                                      mode=self.upsample_mode,
                                      align_corners=True)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y)
        return W_y, sigm_psi_f

    def _normalise_scores(self, psi_f, batch_size, spatial_size):
        """ Normalise the raw compatibility scores according to the mode.
        """
        if self.mode in ("concatenation", "concatenation_sigmoid"):
            # Plain 'concatenation' is accepted by the constructor (and is
            # the default): gate it with a sigmoid instead of failing at
            # the first forward pass.
            return torch.sigmoid(psi_f)

        psi_f_flat = psi_f.view(batch_size, 1, -1)
        if self.mode == "concatenation_softmax":
            scores = func.softmax(psi_f_flat, dim=2)
        elif self.mode == "concatenation_mean":
            # NOTE: no clamping, a null sum yields inf/nan scores.
            psi_f_sum = torch.sum(psi_f_flat, dim=2, keepdim=True)
            scores = psi_f_flat / psi_f_sum
        elif self.mode == "concatenation_mean_flow":
            psi_f_min = psi_f_flat.min(dim=2, keepdim=True)[0]
            psi_f_flat = psi_f_flat - psi_f_min
            psi_f_sum = torch.sum(psi_f_flat, dim=2, keepdim=True)
            scores = psi_f_flat / psi_f_sum
        elif self.mode == "concatenation_range_normalise":
            psi_f_max = torch.max(psi_f_flat, dim=2, keepdim=True)[0]
            psi_f_min = torch.min(psi_f_flat, dim=2, keepdim=True)[0]
            scores = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min)
        else:
            raise NotImplementedError(
                "Unknown attention mode '{0}'.".format(self.mode))
        return scores.view(batch_size, 1, *spatial_size)
class Conv2(nn.Module):
    """ A stack of 'n' successive 2-d convolutions, each optionally followed
    by a batch normalisation, then by an in-place ReLU.
    """
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1,
                 padding=1):
        """ Init class.

        Parameters
        ----------
        in_size: int
            number of channels of the input.
        out_size: int
            number of channels produced by every convolution.
        is_batchnorm: bool
            insert a batch normalisation between convolution and ReLU.
        n: int, default 2
            number of successive convolutions.
        ks: int, default 3
            kernel size of the convolutions.
        stride: int, default 1
            stride of the convolutions.
        padding: int, default 1
            padding of the convolutions.
        """
        # Inheritance
        super(Conv2, self).__init__()

        # Parameters
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding

        # Successive convolutions, registered as conv1 ... conv<n>
        n_channels = in_size
        for index in range(1, n + 1):
            layers = [nn.Conv2d(n_channels, out_size, ks, stride, padding)]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.ReLU(inplace=True))
            setattr(self, "conv{0}".format(index), nn.Sequential(*layers))
            n_channels = out_size

        # Initialise weights
        for child in self.children():
            init_weights(child, init_type="kaiming")

    def forward(self, inputs):
        """ Apply conv1 ... conv<n> in sequence.
        """
        outputs = inputs
        for index in range(1, self.n + 1):
            outputs = getattr(self, "conv{0}".format(index))(outputs)
        return outputs
class Conv3(nn.Module):
    """ Two successive 3-d convolutions, each optionally followed by a batch
    normalisation, then by an in-place ReLU.
    """
    def __init__(self, in_size, out_size, is_batchnorm,
                 kernel_size=(3, 3, 1), padding_size=(1, 1, 0),
                 init_stride=(1, 1, 1)):
        """ Init class.

        Parameters
        ----------
        in_size: int
            number of channels of the input.
        out_size: int
            number of channels produced by both convolutions.
        is_batchnorm: bool
            insert a batch normalisation between convolution and ReLU.
        kernel_size: tuple, default (3, 3, 1)
            kernel size of both convolutions.
        padding_size: tuple, default (1, 1, 0)
            padding of both convolutions.
        init_stride: tuple, default (1, 1, 1)
            stride of the first convolution (the second one uses 1).
        """
        # Inheritance
        super(Conv3, self).__init__()

        def make_stage(n_channels, stride):
            # conv -> [batch norm] -> relu
            layers = [nn.Conv3d(n_channels, out_size, kernel_size, stride,
                                padding_size)]
            if is_batchnorm:
                layers.append(nn.BatchNorm3d(out_size))
            layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        self.conv1 = make_stage(in_size, init_stride)
        self.conv2 = make_stage(out_size, 1)

        # Initialise weights
        for child in self.children():
            init_weights(child, init_type="kaiming")

    def forward(self, inputs):
        """ Apply the two convolution stages in sequence.
        """
        return self.conv2(self.conv1(inputs))
class GridAttentionBlock2D(_GridAttentionBlockND):
    """ Grid attention block working on 2-d feature maps: the generic block
    with 'dimension' fixed to 2.
    """
    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 mode="concatenation", sub_sample_factor=(1, 1),
                 bn_layer=True, use_W=True, use_phi=True, use_theta=True,
                 use_psi=True, nonlinearity1="relu"):
        """ Init class: all the parameters are forwarded unchanged to the
        generic N-d block.
        """
        options = dict(
            inter_channels=inter_channels,
            gating_channels=gating_channels,
            dimension=2,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer,
            use_W=use_W,
            use_phi=use_phi,
            use_theta=use_theta,
            use_psi=use_psi,
            nonlinearity1=nonlinearity1)
        super(GridAttentionBlock2D, self).__init__(in_channels, **options)
class GridAttentionBlock3D(_GridAttentionBlockND):
    """ Grid attention block working on 3-d feature maps: the generic block
    with 'dimension' fixed to 3.
    """
    def __init__(self, in_channels, gating_channels, inter_channels=None,
                 mode="concatenation", sub_sample_factor=(1, 1, 1),
                 bn_layer=True):
        """ Init class: the parameters are forwarded unchanged to the
        generic N-d block, whose remaining options keep their defaults.
        """
        options = dict(
            inter_channels=inter_channels,
            gating_channels=gating_channels,
            dimension=3,
            mode=mode,
            sub_sample_factor=sub_sample_factor,
            bn_layer=bn_layer)
        super(GridAttentionBlock3D, self).__init__(in_channels, **options)
def weights_init_normal(m):
    """ Initialise one module in place, based on its class name.

    'Conv' and 'Linear' weights are drawn from N(0, 0.02); 'BatchNorm'
    weights from N(1, 0.02) with a null bias. Other modules are untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind or "Linear" in kind:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in kind:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_xavier(m):
    """ Initialise one module in place, based on its class name.

    'Conv' and 'Linear' weights use the Xavier normal scheme with a gain of
    1; 'BatchNorm' weights are drawn from N(1, 0.02) with a null bias. Other
    modules are untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind or "Linear" in kind:
        init.xavier_normal_(m.weight.data, gain=1)
    elif "BatchNorm" in kind:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_kaiming(m):
    """ Initialise one module in place, based on its class name.

    'Conv' and 'Linear' weights use the Kaiming normal scheme (a=0, fan in
    mode); 'BatchNorm' weights are drawn from N(1, 0.02) with a null bias.
    Other modules are untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind or "Linear" in kind:
        init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
    elif "BatchNorm" in kind:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
    """ Initialise one module in place, based on its class name.

    'Conv' and 'Linear' weights use the orthogonal scheme with a gain of 1;
    'BatchNorm' weights are drawn from N(1, 0.02) with a null bias. Other
    modules are untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind or "Linear" in kind:
        init.orthogonal_(m.weight.data, gain=1)
    elif "BatchNorm" in kind:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type="normal"):
    """ Initialise a network in place with the requested scheme.

    Parameters
    ----------
    net: nn.Module
        the network; the initialiser is applied through 'net.apply'.
    init_type: str, default 'normal'
        one of 'normal', 'xavier', 'kaiming' or 'orthogonal'.

    Raises
    ------
    NotImplementedError
        if 'init_type' is not a known scheme.
    """
    initialisers = {
        "normal": weights_init_normal,
        "xavier": weights_init_xavier,
        "kaiming": weights_init_kaiming,
        "orthogonal": weights_init_orthogonal,
    }
    if init_type not in initialisers:
        raise NotImplementedError(
            "Initialization method {0} is not implemented.".format(init_type))
    net.apply(initialisers[init_type])

Follow us

© 2019, pynet developers.
Inspired by AZMIND template.