Helper Module for Deep Learning.
pynet echocardiography segmentation¶
Credit: A Grigis
pynet is a Python package related to deep learning and its application in MRI mediacal data analysis. It is accessible to everybody, and is reusable in various contexts. The project is hosted on github: https://github.com/neurospin/pynet.
In this hands-on session, we will use this U-Net architecture to segment 2D echocardiography images. In particular, we will focus on the segmentation of three adjacent cardiac structures: the left ventricle, the myocardium and the right ventricle. The segmentation of these ultrasound images is particularly difficult due to many sources of artifacts, the recognition of the structures to segment, and the subjective delineation of the contours (e.g. at the lower part of the myocardial segmentation mask).
Import the dataset¶
You may need to change the ‘datasetdir’ parameter.
import os
import numpy as np
from pynet.datasets import DataManager, fetch_echocardiography
from pynet.plotting import plot_data
from pynet.utils import setup_logging
setup_logging(level="info")
data = fetch_echocardiography(
datasetdir="/tmp/echocardiography")
manager = DataManager(
input_path=data.input_path,
metadata_path=data.metadata_path,
output_path=data.output_path,
number_of_folds=2,
stratify_label="label",
sampler="weighted_random",
batch_size=10,
test_size=0.1,
sample_size=(1 if "CI_MODE" not in os.environ else 0.05))
dataset = manager["test"]
print(dataset.inputs.shape, dataset.outputs.shape)
data = np.concatenate((dataset.inputs, dataset.outputs), axis=1)
plot_data(data, nb_samples=5)
Optimisation¶
From the available models load the UNet, and start the training. You may need to change the ‘outdir’ parameter.
import torch
import torch.nn as nn
from pynet import NetParameters
from pynet.interfaces import UNetEncoder
from pynet.plotting import plot_history
from pynet.history import History
def my_loss(x, y):
""" nn.CrossEntropyLoss expects a torch.LongTensor containing the class
indices without the channel dimension.
"""
# y = torch.sum(y, dim=1).type(torch.LongTensor)
device = y.get_device()
y = torch.argmax(y, dim=1).type(torch.LongTensor)
if device != -1:
y = y.to(device)
criterion = nn.CrossEntropyLoss()
return criterion(x, y)
outdir = "/tmp/echocardiography"
trained_model = os.path.join(outdir, "model_0_epoch_9.pth")
unet_params = NetParameters(
num_classes=4,
in_channels=1,
depth=5,
start_filts=16,
up_mode="upsample",
merge_mode="concat",
batchnorm=False,
dim="2d")
if os.path.isfile(trained_model):
unet = UNetEncoder(
unet_params,
optimizer_name="Adam",
learning_rate=5e-4,
metrics=["multiclass_dice"],
loss=my_loss,
pretrained=trained_model,
use_cuda=False)
train_history = History.load(
os.path.join(outdir, "train_0_epoch_9.pkl"))
valid_history = History.load(
os.path.join(outdir, "validation_0_epoch_9.pkl"))
else:
unet = UNetEncoder(
unet_params,
optimizer_name="Adam",
learning_rate=5e-4,
metrics=["multiclass_dice"],
loss=my_loss,
use_cuda=False)
print(unet.model)
train_history, valid_history = unet.training(
manager=manager,
nb_epochs=(10 if "CI_MODE" not in os.environ else 1),
checkpointdir=outdir,
fold_index=0,
with_validation=True)
print(train_history)
print(valid_history)
plot_history(train_history)
Testing¶
Finaly use the testing set and check the results.
y_pred, X, y_true, loss, values = unet.testing(
manager=manager,
with_logit=True,
predict=True)
print(y_pred.shape, X.shape, y_true.shape)
y_pred = np.expand_dims(y_pred, axis=1)
data = np.concatenate((y_pred, y_true, X), axis=1)
plot_data(data, nb_samples=5, random=False)
if "CI_MODE" not in os.environ:
import matplotlib.pyplot as plt
plt.show()
Total running time of the script: ( 0 minutes 0.000 seconds)
Gallery generated by Sphinx-Gallery
Follow us
Inspired by AZMIND template.