Menu

Helper Module for Deep Learning.

pynet metastasis tumor segmentation

Credit: A Grigis

pynet is a Python package for deep learning and its application to the analysis of medical MRI data. It is accessible to everybody and reusable in various contexts. The project is hosted on GitHub: https://github.com/neurospin/pynet.

In this notebook we will learn how to segment tumors using MRI images from the BraTS dataset. The NvNet proposed by Andriy Myronenko will be trained. This network is a combination of a V-Net (3D U-Net) and a VAE (variational auto-encoder).

import os
import sys

# Stop right here when the CI_MODE environment variable is defined, so
# that none of the remaining statements of this example are executed.
if os.environ.get("CI_MODE") is not None:
    sys.exit()

Inspect the NvNet network

Inspect some layers of the network.

from pprint import pprint
import pynet.models as models
from pynet import NetParameters
from pynet.utils import get_named_layers
from pynet.utils import setup_logging
# from pynet.plotting.network import plot_net_rescue

# Configure logging through the pynet helper, at the "info" level.
setup_logging(level="info")

# Settings of the network to inspect: one input channel, a 128x128x128
# input shape, four output classes and the VAE branch enabled.
nvnet_settings = dict(
    input_shape=(128, 128, 128),
    in_channels=1,
    num_classes=4,
    activation="relu",
    normalization="group_normalization",
    mode="trilinear",
    with_vae=True)
model = models.NvNet(**nvnet_settings)

# Collect the named layers of the model and pretty-print them.
layers = get_named_layers(model)
pprint(layers)
# graph_file = plot_net_rescue(model, (1, 1, 128, 128, 128))

Import the brats dataset

Use the fetcher of the pynet package.

from pynet.datasets import DataManager, fetch_brats
from pynet.plotting import plot_data

# Fetch the dataset description; the returned object exposes the input,
# metadata and output paths consumed by the data manager below.
brats_dir = "/neurospin/nsap/datasets/brats"
data = fetch_brats(datasetdir=brats_dir)

# Data manager settings: 10 folds, batches of one sample, stratification
# on the "grade" label and a test size of 0.1.
manager_settings = dict(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    output_path=data.output_path,
    projection_labels=None,
    number_of_folds=10,
    batch_size=1,
    stratify_label="grade",
    # input_transforms=[
    #     RandomFlipDimensions(ndims=3, proba=0.5, with_channels=True),
    #     Offset(nb_channels=4, factor=0.1)],
    sampler="random",
    add_input=True,
    test_size=0.1,
    pin_memory=True)
manager = DataManager(**manager_settings)

# Take the first element of the test set, report the shapes of its inputs
# and outputs, then display channel 1 of both arrays.
dataset = manager["test"][:1]
print(dataset.inputs.shape, dataset.outputs.shape)
for volumes in (dataset.inputs, dataset.outputs):
    plot_data(volumes, channel=1, nb_samples=5)

Training

From the available models load the 3D NvNet, and start the training.

import os
from torch.optim import lr_scheduler
import pynet
from pynet.losses import NvNetCombinedLoss
from pynet.interfaces import NvNetSegmenter
from pynet.plotting import plot_history
from pynet.history import History

# Retrieve the loss class from the pynet "losses" tools and instantiate it
# for the four segmentation classes.
losses = pynet.get_tools(tool_name="losses")
my_loss = losses["NvNetCombinedLoss"](
    num_classes=4,
    k1=0.1,
    k2=0.1)

# Directory where the checkpoints are written and searched for.
# makedirs(exist_ok=True) also creates missing parent directories and does
# not fail when the directory already exists.
outdir = "/neurospin/nsap/tmp/nvnet"
os.makedirs(outdir, exist_ok=True)

# Every checkpoint file name is derived from the number of epochs, so the
# model weights and the histories reloaded below always refer to the same
# (last) epoch.
# NOTE(review): assumes pynet saves 'train_0_epoch_<N>.pkl' and
# 'validation_0_epoch_<N>.pkl' next to 'model_0_epoch_<N>.pth' - confirm
# against the checkpoint naming used by pynet.
nb_epochs = 100
last_epoch = nb_epochs - 1
trained_model = os.path.join(
    outdir, "model_0_epoch_{0}.pth".format(last_epoch))

nvnet_params = NetParameters(
    input_shape=(150, 190, 135),
    in_channels=4,
    num_classes=4,
    activation="relu",
    normalization="group_normalization",
    mode="trilinear",
    with_vae=True)

# Optimisation settings shared by the "resume" and "train" branches, kept
# in a single place so that both branches cannot diverge.
segmenter_kwargs = dict(
    optimizer_name="Adam",
    learning_rate=1e-4,
    weight_decay=1e-5,
    loss=my_loss,
    use_cuda=True)

if os.path.isfile(trained_model):
    # A final checkpoint is available: reload the weights and the matching
    # training/validation histories instead of training again.
    nvnet = NvNetSegmenter(
        nvnet_params,
        pretrained=trained_model,
        **segmenter_kwargs)
    train_history = History.load(
        os.path.join(outdir, "train_0_epoch_{0}.pkl".format(last_epoch)))
    valid_history = History.load(
        os.path.join(
            outdir, "validation_0_epoch_{0}.pkl".format(last_epoch)))
else:
    # No checkpoint found: train from scratch.
    nvnet = NvNetSegmenter(
        nvnet_params,
        **segmenter_kwargs)
    # Halve the learning rate once the monitored quantity has stopped
    # decreasing for more than 5 epochs.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer=nvnet.optimizer,
        mode="min",
        factor=0.5,
        patience=5)
    train_history, valid_history = nvnet.training(
        manager=manager,
        nb_epochs=nb_epochs,
        checkpointdir=outdir,
        # fold_index=0,
        scheduler=scheduler,
        with_validation=True)

# Report the histories and plot the training one.
print(train_history)
print(valid_history)
plot_history(train_history)

Testing

Finally, use the testing set and check the results.

# Evaluate the segmenter with the data manager, then split the returned
# values into the predictions, the inputs, the ground truth, the loss and
# the extra values.
test_outputs = nvnet.testing(
    manager=manager,
    with_logit=False,
    predict=False)
y_pred, X, y_true, loss, values = test_outputs
print(y_pred.shape, X.shape, y_true.shape)
# y_pred = np.expand_dims(y_pred, axis=1)
# data = np.concatenate((y_pred, y_true, X), axis=1)
# plot_data(data, nb_samples=5)

# Display the figures, except when the CI_MODE environment variable is set.
if not ("CI_MODE" in os.environ):
    import matplotlib.pyplot as plt
    plt.show()

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery

Follow us

© 2019, pynet developers .
Inspired by AZMIND template.