Helper Module for Deep Learning.
pynet: predicting autism
Credit: A Grigis
This practice is based on the IMPAC challenge, https://paris-saclay-cds.github.io/autism_challenge.
Autism spectrum disorder (ASD) is a severe psychiatric disorder that affects 1 in 166 children. In the IMPAC challenge, ML models were trained on anatomical and functional features derived from the database to diagnose a subject as autistic or healthy. Here we implement the best neural network for this task, as proposed in https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6859452, i.e. a dense feedforward network.
import os
import sys
# Stop the script right away when the CI_MODE environment variable is set
# (presumably because the dataset and a training run are not available in
# continuous integration - confirm with the CI configuration).
if os.getenv("CI_MODE") is not None:
    sys.exit()
import logging
import pynet
from pynet.datasets import fetch_impac
from pynet.datasets import DataManager
from pynet.utils import setup_logging
from pynet.plotting import Board, update_board
from pynet.interfaces import DeepLearningInterface
from pynet.metrics import SKMetrics
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc
import collections
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import pandas as pd
from scipy.stats import spearmanr
import seaborn as sns
import matplotlib.pyplot as plt
# Configure logging and keep a handle on the "pynet" logger, which is also
# used by the loss function defined further down.
setup_logging(level="info")
logger = logging.getLogger("pynet")

# use_toy=True replaces the IMPAC data with a small synthetic two-class
# problem; dtype is forwarded to fetch_impac to select the kind of features.
use_toy = False
dtype = "fmri"

if use_toy:
    # Synthetic data: class 0 is drawn from a standard normal distribution,
    # class 1 from the same distribution shifted by +1.5 on every feature.
    # Each split therefore contains 2 * nb_samples samples.
    nb_features = 50
    toy_data = {}
    for name, nb_samples in (("train", 1000), ("test", 2)):
        x1 = torch.randn(nb_samples, nb_features)
        x2 = torch.randn(nb_samples, nb_features) + 1.5
        x = torch.cat([x1, x2], dim=0)
        y1 = torch.zeros(nb_samples, 1)
        y2 = torch.ones(nb_samples, 1)
        y = torch.cat([y1, y2], dim=0)
        toy_data[name] = (x, y)
        if name == "train":
            # Show the first two features of both classes.
            plt.figure()
            plt.scatter(x1[:, 0], x1[:, 1], color="b")
            plt.scatter(x2[:, 0], x2[:, 1], color="r")
    manager = DataManager.from_numpy(
        train_inputs=toy_data["train"][0],
        train_labels=toy_data["train"][1],
        batch_size=50,
        test_inputs=toy_data["test"][0],
        test_labels=toy_data["test"][1])
else:
    # The real dataset is only fetched when it is actually used, so the toy
    # mode no longer requires the dataset directory to exist.
    data = fetch_impac(
        datasetdir="/neurospin/nsap/datasets/impac",
        mode="train",
        dtype=dtype)
    nb_features = data.nb_features
    # NOTE(review): test_size=0 presumably keeps every sample for the
    # train/validation folds - confirm against the DataManager documentation.
    manager = DataManager(
        input_path=data.input_path,
        labels=["participants_asd"],
        metadata_path=data.metadata_path,
        number_of_folds=10,
        batch_size=128,
        sampler="random",
        test_size=0,
        sample_size=1)
class DenseFeedForwardNet(nn.Module):
    """ Dense feedforward network producing one output value per sample.

    Two stacks are created: 'layers', which is the one applied in the forward
    pass, and 'layers_alt', an alternative stack that is instantiated but not
    used by 'forward'.
    """
    def __init__(self, nb_features):
        """ Initialize the instance.

        Parameters
        ----------
        nb_features: int
            the size of the feature vector.
        """
        super().__init__()

        # Stack used by the forward pass: nb_features -> 64 -> 64 -> 1.
        main_stack = collections.OrderedDict()
        main_stack["linear1"] = nn.Linear(nb_features, 64)
        main_stack["relu1"] = nn.LeakyReLU(negative_slope=0.01)
        main_stack["dropout"] = nn.Dropout(0.13)
        main_stack["linear2"] = nn.Linear(64, 64)
        main_stack["relu2"] = nn.LeakyReLU(negative_slope=0.01)
        main_stack["linear3"] = nn.Linear(64, 1)
        self.layers = nn.Sequential(main_stack)

        # Alternative stack: nb_features -> 50 -> 100 -> 1.
        alt_stack = collections.OrderedDict()
        alt_stack["linear1"] = nn.Linear(nb_features, 50)
        alt_stack["relu1"] = nn.ReLU()
        alt_stack["dropout"] = nn.Dropout(0.2)
        alt_stack["linear2"] = nn.Linear(50, 100)
        alt_stack["relu2"] = nn.PReLU(1)
        alt_stack["linear3"] = nn.Linear(100, 1)
        self.layers_alt = nn.Sequential(alt_stack)

    def forward(self, x):
        """ Apply the 'layers' stack to the input features.

        Parameters
        ----------
        x: torch.Tensor
            the input features, with nb_features elements on the last axis.

        Returns
        -------
        out: torch.Tensor
            the raw network output, with one element on the last axis.
        """
        out = self.layers(x)
        return out
def my_loss(x, y):
    """ Binary cross-entropy loss computed from raw logits.

    Parameters
    ----------
    x: torch.Tensor
        the network outputs (logits); flattened to a column of shape (N, 1).
    y: torch.Tensor
        the targets, with as many elements as 'x'; flattened to (N, 1),
        cast to float32 and moved to the device of 'x'.

    Returns
    -------
    loss: torch.Tensor
        the loss returned by nn.BCEWithLogitsLoss with its default settings.
    """
    logger.debug("Binary cross-entropy loss...")
    # reshape (unlike view) also accepts non-contiguous tensors.
    x = x.reshape(-1, 1)
    # The targets must live on the same device as the logits; moving 'y' to
    # its own device, as done previously, could never fix a mismatch.
    y = y.reshape(-1, 1).to(device=x.device, dtype=torch.float32)
    logger.debug(" x: %s - %s", x.shape, x.dtype)
    logger.debug(" y: %s - %s", y.shape, y.dtype)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(x, y)
def plot_metric_rank_correlations(metrics):
    """ Display rank correlations for all numerical metrics calculated over N
    experiments.

    The Spearman correlations between the columns are drawn as an annotated
    heatmap on a new figure.

    Parameters
    ----------
    metrics: DataFrame
        a data frame with all computed metrics as columns and N rows.
    """
    labels = metrics.columns
    corr = spearmanr(metrics)[0]
    if len(labels) == 2:
        # With exactly two variables spearmanr returns a single coefficient
        # instead of a matrix: expand it to the 2x2 matrix the heatmap needs.
        corr = [[1.0, corr], [corr, 1.0]]
    _, ax = plt.subplots()
    sns.heatmap(corr, annot=True, cmap=plt.get_cmap("Blues"),
                xticklabels=labels, yticklabels=labels, ax=ax)
# Build the network for the selected feature size and print its layers.
model = DenseFeedForwardNet(nb_features)
print(model)
# Wrap the model in the pynet interface: Adam optimizer, the custom
# binary cross-entropy loss defined above, and two metrics referenced by name.
cl = DeepLearningInterface(
optimizer_name="Adam",
learning_rate=1e-4,
weight_decay=1.1e-4,
metrics=["binary_accuracy", "sk_roc_auc_score"],
loss=my_loss,
model=model)
# Attach a board and refresh it through the "after_epoch" observer.
# NOTE(review): this presumably needs a board server listening on
# http://localhost:8097 - confirm before running.
cl.board = Board(port=8097, host="http://localhost", env="main")
cl.add_observer("after_epoch", update_board)
# Learning rate schedule bound to the interface optimizer: in "min" mode the
# rate is scaled by factor=0.1 once the monitored value has stopped
# decreasing (patience=5).
# NOTE(review): which value is monitored is decided inside cl.training -
# verify against pynet.
scheduler = lr_scheduler.ReduceLROnPlateau(
optimizer=cl.optimizer,
mode="min",
factor=0.1,
patience=5,
verbose=True,
eps=1e-8)
# Train for 50 epochs on fold 0 without saving checkpoints; validation is
# switched off in toy mode.
# NOTE(review): confirm the order of the two returned histories against the
# pynet documentation - the names used here may be swapped.
test_history, train_history = cl.training(
manager=manager,
nb_epochs=50,
checkpointdir=None,
fold_index=0,
scheduler=scheduler,
with_validation=(not use_toy))
# With the real data, replace the training manager by one built on the
# IMPAC test split; in toy mode the existing manager is reused as is.
if not use_toy:
    data = fetch_impac(
        datasetdir="/neurospin/nsap/datasets/impac",
        mode="test",
        dtype=dtype)
    manager = DataManager(
        input_path=data.input_path,
        labels=["participants_asd"],
        metadata_path=data.metadata_path,
        number_of_folds=3,
        batch_size=10,
        sampler=None,
        test_size=1,
        sample_size=1)

# Run the trained model on the test data; a sigmoid is requested so that the
# predictions can be thresholded below.
y_pred, X, y_true, loss, values = cl.testing(
    manager=manager,
    with_logit=True,
    logit_function="sigmoid",
    predict=False)

# Flatten with reshape(-1) rather than squeeze(): squeeze() turns a
# single-sample result into a 0-d value, which cannot be used as a column.
prob = y_pred.reshape(-1)
truth = y_true.reshape(-1)
result = pd.DataFrame.from_dict(collections.OrderedDict([
    ("pred", (prob > 0.5).astype(int)),
    ("truth", truth),
    ("prob", prob)]))
print(result)

# Confusion matrix displayed as an annotated heatmap.
fig, ax = plt.subplots()
cmap = plt.get_cmap('Blues')
cm = SKMetrics("confusion_matrix", with_logit=False)(y_pred, y_true)
sns.heatmap(cm, cmap=cmap, annot=True, fmt="g", ax=ax)
ax.set_xlabel("predicted values")
ax.set_ylabel("actual values")

# Evaluate every registered metric whose name starts with "sk_" and collect
# the values in a one-row data frame.
metrics = {}
sk_metrics = {
    key: val for key, val in pynet.get_tools()["metrics"].items()
    if key.startswith("sk_")}
for name, metric in sk_metrics.items():
    # 'with_logit' is disabled on each metric before it is evaluated.
    metric.with_logit = False
    value = metric(y_pred, y_true)
    metrics.setdefault(name, []).append(value)
metrics = pd.DataFrame.from_dict(metrics)
# NOTE(review): the report uses a 0.4 decision threshold while the table
# above uses 0.5 - confirm that the difference is intended.
print(classification_report(y_true, y_pred >= 0.4))
print(metrics)
# plot_metric_rank_correlations(metrics)

# ROC curve and area under the curve.
fpr, tpr, _ = roc_curve(y_true, y_pred)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, color="darkorange", lw=2,
         label="ROC curve (area = %0.2f)" % roc_auc)
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver operating characteristic")
plt.legend(loc="lower right")
plt.show()
Total running time of the script: ( 0 minutes 0.000 seconds)
Gallery generated by Sphinx-Gallery
Follow us
© 2019, pynet developers.
Inspired by AZMIND template.
Inspired by AZMIND template.