# Helper Module for Deep Learning.
# Source code for pynet.models.spherical.unet
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Module that provides the spherical UNet architecture.
"""
# Imports
import logging
from collections import namedtuple
import torch
import numpy as np
import torch.nn as nn
from joblib import Memory
from .sampling import (
icosahedron, neighbors, number_of_ico_vertices, downsample, interpolate,
neighbors_rec)
from .layers import (
IcoUpConvLayer, IcoUpSampleMaxIndexLayer, IcoUpSampleFixIndexLayer,
IcoUpSampleLayer, IcoPoolLayer, DiNeIcoConvLayer, RePaIcoConvLayer)
from .utils import debug
from pynet.interfaces import DeepLearningDecorator
from pynet.utils import Networks
# Global parameters
logger = logging.getLogger("pynet")
# Per-order icosahedron record: mesh geometry plus the neighborhood,
# down-sampling, up-sampling and convolution index arrays attached to
# that order.
Ico = namedtuple(
    "Ico",
    "order vertices triangles neighbor_indices down_indices up_indices "
    "conv_neighbor_indices")
@Networks.register
@DeepLearningDecorator(family=("encoder", ))
class SphericalUNet(nn.Module):
    """ Define the Spherical UNet structure.

    https://github.com/zhaofenqiang/Spherical_U-Net
    Spherical U-Net on Cortical Surfaces: Methods and Applications
    """
    def __init__(self, in_order, in_channels, out_channels, depth=5,
                 start_filts=32, conv_mode="1ring", up_mode="interp",
                 cachedir=None):
        """ Initialize the Spherical UNet.

        Parameters
        ----------
        in_order: int
            the input icosahedron order.
        in_channels: int
            input features/channels.
        out_channels: int
            output features/channels.
        depth: int, default 5
            number of layers in the UNet. Encoder level `idx` works on the
            icosahedron of order `in_order - idx`, so `depth` must be in
            [1, in_order].
        start_filts: int, default 32
            number of convolutional filters for the first conv; each deeper
            level doubles it.
        conv_mode: str, default '1ring'
            the size of the spherical convolution filter: '1ring' or '2ring'.
            Can also use rectangular grid projection: 'repa'.
        up_mode: str, default 'interp'
            type of upsampling: 'transpose' for transpose
            convolution (1 ring), 'interp' for nearest neighbor linear
            interpolation, 'maxpad' for max pooling shifted zero padding,
            and 'zeropad' for classical zero padding.
        cachedir: str, default None
            set this folder to use smart caching speedup.

        Raises
        ------
        ValueError
            if `conv_mode` or `up_mode` is unknown, or if `depth` is not in
            [1, in_order].
        """
        logger.debug("SphericalUNet init...")
        super(SphericalUNet, self).__init__()

        # Validate the configuration before triggering any (possibly
        # expensive) icosahedron computation.
        if conv_mode not in ("1ring", "2ring", "repa"):
            raise ValueError(
                "Unexpected convolution mode: {0!r}.".format(conv_mode))
        if up_mode not in ("interp", "zeropad", "maxpad", "transpose"):
            raise ValueError("Invalid upsampling method.")
        if not 1 <= depth <= in_order:
            # Deeper levels would need an icosahedron of order < 1, and the
            # FC head needs at least one encoder level (filts[1]).
            raise ValueError(
                "The UNet depth must be in [1, in_order={0}], got "
                "{1}.".format(in_order, depth))

        self.memory = Memory(cachedir, verbose=0)
        self.in_order = in_order
        self.depth = depth
        self.conv_mode = conv_mode
        self.in_vertices = number_of_ico_vertices(order=in_order)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up_mode = up_mode

        # Geometry and neighborhoods for every order 1..in_order.
        self.ico = {}
        icosahedron_cached = self.memory.cache(icosahedron)
        neighbors_cached = self.memory.cache(neighbors)
        neighbors_rec_cached = self.memory.cache(neighbors_rec)
        for order in range(1, in_order + 1):
            vertices, triangles = icosahedron_cached(order=order)
            logger.debug("- ico {0}: verts {1} - tris {2}".format(
                order, vertices.shape, triangles.shape))
            neighs = neighbors_cached(
                vertices, triangles, depth=1, direct_neighbor=True)
            neighs = np.asarray(list(neighs.values()))
            logger.debug("- neighbors {0}: {1}".format(
                order, neighs.shape))
            if conv_mode == "1ring":
                conv_neighs = neighs
                logger.debug("- neighbors {0}: {1}".format(
                    order, conv_neighs.shape))
            elif conv_mode == "2ring":
                conv_neighs = neighbors_cached(
                    vertices, triangles, depth=2, direct_neighbor=True)
                conv_neighs = np.asarray(list(conv_neighs.values()))
                logger.debug("- neighbors {0}: {1}".format(
                    order, conv_neighs.shape))
            else:  # "repa" (validated above)
                conv_neighs, conv_weights, _ = neighbors_rec_cached(
                    vertices, triangles, size=5, zoom=5)
                logger.debug("- neighbors {0}: {1} - {2}".format(
                    order, conv_neighs.shape, conv_weights.shape))
                conv_neighs = (conv_neighs, conv_weights)
            self.ico[order] = Ico(
                order=order, vertices=vertices, triangles=triangles,
                neighbor_indices=neighs, down_indices=None, up_indices=None,
                conv_neighbor_indices=conv_neighs)

        # Downsampling indices are computed from the vertices of `order`
        # and `order - 1`, and stored on the finer `order`.
        downsample_cached = self.memory.cache(downsample)
        for order in range(in_order, 1, -1):
            down_indices = downsample_cached(
                self.ico[order].vertices, self.ico[order - 1].vertices)
            logger.debug("- down {0}: {1}".format(order, down_indices.shape))
            self.ico[order] = self.ico[order]._replace(
                down_indices=down_indices)

        # Upsampling indices are computed from `order` to `order + 1`, and
        # stored on the coarser `order`.
        interpolate_cached = self.memory.cache(interpolate)
        for order in range(1, in_order):
            up_indices = interpolate_cached(
                self.ico[order].vertices, self.ico[order + 1].vertices,
                self.ico[order + 1].triangles)
            up_indices = np.asarray(list(up_indices.values()))
            logger.debug("- up {0}: {1}".format(order, up_indices.shape))
            self.ico[order] = self.ico[order]._replace(
                up_indices=up_indices)

        self.filts = [in_channels] + [
            start_filts * 2 ** idx for idx in range(depth)]
        logger.debug("- filters: {0}".format(self.filts))
        if conv_mode == "repa":
            self.sconv = RePaIcoConvLayer
        else:
            self.sconv = DiNeIcoConvLayer

        # Encoder: block idx works at order in_order - idx. Every block but
        # the first pools from order + 1 down to order.
        for idx in range(depth):
            order = self.in_order - idx
            finer = None if idx == 0 else self.ico[order + 1]
            logger.debug(
                "- DownBlock {0}: {1} -> {2} [{3} - {4} - {5}]".format(
                    idx, self.filts[idx], self.filts[idx + 1],
                    self.ico[order].neighbor_indices.shape,
                    None if finer is None else finer.neighbor_indices.shape,
                    None if finer is None else finer.down_indices.shape))
            block = DownBlock(
                conv_layer=self.sconv,
                in_ch=self.filts[idx],
                out_ch=self.filts[idx + 1],
                conv_neigh_indices=self.ico[order].conv_neighbor_indices,
                down_neigh_indices=(
                    None if finer is None else finer.neighbor_indices),
                down_indices=(
                    None if finer is None else finer.down_indices),
                pool_mode=("max" if self.up_mode == "maxpad" else "mean"),
                first=(idx == 0))
            setattr(self, "down{0}".format(idx + 1), block)

        # Decoder: up block `cnt` upsamples from order to order + 1, with
        # order = in_order - idx mirroring the matching encoder level.
        for cnt, idx in enumerate(range(depth - 1, 0, -1), start=1):
            order = self.in_order - idx
            logger.debug("- UpBlock {0}: {1} -> {2} [{3} - {4}]".format(
                cnt, self.filts[idx + 1], self.filts[idx],
                self.ico[order + 1].neighbor_indices.shape,
                self.ico[order].up_indices.shape))
            block = UpBlock(
                conv_layer=self.sconv,
                in_ch=self.filts[idx + 1],
                out_ch=self.filts[idx],
                conv_neigh_indices=self.ico[order + 1].conv_neighbor_indices,
                neigh_indices=self.ico[order + 1].neighbor_indices,
                up_neigh_indices=self.ico[order].up_indices,
                down_indices=self.ico[order + 1].down_indices,
                up_mode=self.up_mode)
            setattr(self, "up{0}".format(cnt), block)

        logger.debug("- FC: {0} -> {1}".format(self.filts[1], out_channels))
        self.fc = nn.Sequential(
            nn.Linear(self.filts[1], out_channels))

    def forward(self, x):
        """ Forward pass.

        Parameters
        ----------
        x: Tensor
            input data whose dimension 2 must have `in_vertices` entries
            (presumably (n_samples, in_channels, in_vertices) - the FC head
            assumes this layout for the decoder output; TODO confirm for
            the input).

        Returns
        -------
        x: Tensor (n_samples, out_channels, in_vertices)
            the predicted features on the input icosahedron.

        Raises
        ------
        RuntimeError
            if the input is not sampled on an `in_order` icosahedron.
        """
        logger.debug("SphericalUNet...")
        debug("input", x)
        if x.size(2) != self.in_vertices:
            raise RuntimeError("Input data must be projected on an {0} order "
                               "icosahedron.".format(self.in_order))

        # Encoder: keep every block output (skip connections) and the
        # pooling indices (consumed by the 'maxpad' upsampling).
        encoder_outs = []
        pooling_outs = []
        for idx in range(1, self.depth + 1):
            down_block = getattr(self, "down{0}".format(idx))
            # Lazy %-args: the module repr is only built when DEBUG is on.
            logger.debug("- filter %s: %s", idx, down_block)
            x, max_pool_indices = down_block(x)
            encoder_outs.append(x)
            pooling_outs.append(max_pool_indices)
        encoder_outs = encoder_outs[::-1]
        pooling_outs = pooling_outs[::-1]

        # Decoder: walk the encoder levels back up, in reverse order.
        for idx in range(1, self.depth):
            up_block = getattr(self, "up{0}".format(idx))
            logger.debug("- filter %s: %s", idx, up_block)
            x = up_block(x, encoder_outs[idx], pooling_outs[idx - 1])

        # Vertex-wise linear head: (N, C, V) -> (N * V, C) -> FC -> (N, C', V).
        logger.debug("FC...")
        debug("input", x)
        n_samples = len(x)
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * self.in_vertices, self.filts[1])
        x = self.fc(x)
        x = x.view(n_samples, self.in_vertices, self.out_channels)
        x = x.permute(0, 2, 1)
        debug("output", x)
        return x
class DownBlock(nn.Module):
    """ Downsampling block in spherical UNet:
    pooling ('mean' or 'max', skipped for the first block)
    => (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, conv_layer, in_ch, out_ch, conv_neigh_indices,
                 down_neigh_indices, down_indices, pool_mode="mean",
                 first=False):
        """ Init.

        Parameters
        ----------
        conv_layer: nn.Module
            the convolutional layer on icosahedron discretized sphere; called
            as conv_layer(in_channels, out_channels, neighborhood_indices).
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        conv_neigh_indices: array
            conv layer's filters' neighborhood indices at sampling i.
        down_neigh_indices: array
            conv layer's filters' neighborhood indices at sampling i + 1
            (unused when `first` is set).
        down_indices: array
            downsampling indices at sampling i (unused when `first` is set).
        pool_mode: str, default 'mean'
            the pooling mode: 'mean' or 'max'.
        first: bool, default False
            if set skip the pooling block.
        """
        super(DownBlock, self).__init__()
        self.first = first
        if not self.first:
            self.pooling = IcoPoolLayer(
                down_neigh_indices, down_indices, pool_mode)
        # Two conv => BN => LeakyReLU stages: the first maps in_ch to out_ch,
        # the second keeps out_ch.
        stages = []
        for n_inputs in (in_ch, out_ch):
            stages.extend([
                conv_layer(n_inputs, out_ch, conv_neigh_indices),
                nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                               track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """ Pool (unless first block), then apply the double convolution.

        Returns
        -------
        x: Tensor
            the block output.
        max_pool_indices: object or None
            the second value returned by the pooling layer, or None for the
            first block.
        """
        logger.debug("- DownBlock")
        debug("input", x)
        if self.first:
            max_pool_indices = None
        else:
            x, max_pool_indices = self.pooling(x)
            debug("pooling", x)
            if max_pool_indices is not None:
                debug("max pooling indices", max_pool_indices)
        x = self.block(x)
        debug("output", x)
        return x, max_pool_indices
class UpBlock(nn.Module):
    """ Define the upsampling block in spherical UNet:
    upsample => concatenate skip connection => (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, conv_layer, in_ch, out_ch, conv_neigh_indices,
                 neigh_indices, up_neigh_indices, down_indices,
                 up_mode="interp"):
        """ Init.

        Parameters
        ----------
        conv_layer: nn.Module
            the convolutional layer on icosahedron discretized sphere.
        in_ch: int
            input features/channels. It is also the number of channels the
            first convolution receives after the upsampled features are
            concatenated with the skip connection (presumably out_ch
            upsampled + out_ch skip channels, i.e. in_ch == 2 * out_ch -
            TODO confirm for standalone use).
        out_ch: int
            output features/channels.
        conv_neigh_indices: tensor, int
            conv layer's filters' neighborhood indices at sampling i.
        neigh_indices: tensor, int
            neighborhood indices at sampling i; only used by the 'maxpad'
            and 'transpose' modes.
        up_neigh_indices: array
            upsampling neighborhood indices at sampling i + 1; only used by
            the 'interp' and 'zeropad' modes.
        down_indices: array
            downsampling indices at sampling i; only used by the 'maxpad'
            and 'transpose' modes.
        up_mode: str, default 'interp'
            type of upsampling: 'transpose' for transpose
            convolution, 'interp' for nearest neighbor linear interpolation,
            'maxpad' for max pooling shifted zero padding, and 'zeropad' for
            classical zero padding.

        Raises
        ------
        ValueError
            if `up_mode` is not one of the supported methods.
        """
        super(UpBlock, self).__init__()
        self.up_mode = up_mode
        if up_mode == "interp":
            self.up = IcoUpSampleLayer(in_ch, out_ch, up_neigh_indices)
        elif up_mode == "zeropad":
            self.up = IcoUpSampleFixIndexLayer(in_ch, out_ch, up_neigh_indices)
        elif up_mode == "maxpad":
            self.up = IcoUpSampleMaxIndexLayer(
                in_ch, out_ch, neigh_indices, down_indices)
        elif up_mode == "transpose":
            self.up = IcoUpConvLayer(
                in_ch, out_ch, neigh_indices, down_indices)
        else:
            raise ValueError("Invalid upsampling method.")
        self.double_conv = nn.Sequential(
            conv_layer(in_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            conv_layer(out_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x1, x2, max_pool_indices):
        """ Upsample `x1`, merge it with the skip features `x2`, convolve.

        Parameters
        ----------
        x1: Tensor
            features coming from the coarser level.
        x2: Tensor
            skip-connection features from the encoder at the target level.
        max_pool_indices: object or None
            pooling indices from the matching encoder block; only used in
            'maxpad' mode.

        Returns
        -------
        x: Tensor
            the block output.
        """
        logger.debug("- UpBlock")
        debug("input", x1)
        debug("skip", x2)
        if self.up_mode == "maxpad":
            # Unpooling needs the indices recorded during encoding.
            x1 = self.up(x1, max_pool_indices)
        else:
            x1 = self.up(x1)
        debug("upsampling", x1)
        # Concatenate along dimension 1 (presumably the channel axis).
        x = torch.cat((x1, x2), 1)
        debug("cat", x)
        x = self.double_conv(x)
        debug("output", x)
        return x
# Follow us
# © 2019, pynet developers.
# Inspired by AZMIND template.