{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet: predicting autism\n========================\n\nCredit: A Grigis\n\nThis practice is based on the IMPAC challenge,\nhttps://paris-saclay-cds.github.io/autism_challenge.\n\nAutism spectrum disorder (ASD) is a severe psychiatric disorder that affects\n1 in 166 children. In the IMPAC challenge, ML models were trained on the\nanatomical and functional features derived from the database to diagnose a\nsubject as autistic or healthy. Here we implement the best-performing neural\nnetwork for this task, proposed in\nhttps://www.ncbi.nlm.nih.gov/pmc/articles/PMC6859452, i.e. a dense\nfeedforward network.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "if \"CI_MODE\" in os.environ:\n",
        "    sys.exit()\n",
        "import logging\n",
        "import collections\n",
        "import numpy as np\n",
        "import pynet\n",
        "from pynet.datasets import fetch_impac\n",
        "from pynet.datasets import DataManager\n",
        "from pynet.utils import setup_logging\n",
        "from pynet.plotting import Board, update_board\n",
        "from pynet.interfaces import DeepLearningInterface\n",
        "from pynet.metrics import SKMetrics\n",
        "from sklearn.metrics import classification_report\n",
        "from sklearn.metrics import roc_curve, auc\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.optim import lr_scheduler\n",
        "import pandas as pd\n",
        "from scipy.stats import spearmanr\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "setup_logging(level=\"info\")\n",
        "logger = logging.getLogger(\"pynet\")\n",
        "\n",
        "# Configuration.\n",
        "use_toy = False\n",
        "dtype = \"fmri\"\n",
        "# Location of the IMPAC dataset; override it with the IMPAC_DATASETDIR\n",
        "# environment variable instead of editing the notebook.\n",
        "DATASETDIR = os.environ.get(\n",
        "    \"IMPAC_DATASETDIR\", \"/neurospin/nsap/datasets/impac\")\n",
        "# Probability cut-off used to turn predicted probabilities into labels;\n",
        "# shared by every thresholded output below so that they all agree.\n",
        "THRESHOLD = 0.5\n",
        "# Seed the random generators so that weight initialization, toy data and\n",
        "# shuffling are reproducible from one run to the next.\n",
        "SEED = 42\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "if use_toy:\n",
        "    # Two shifted Gaussian clouds: a sanity check that the network and the\n",
        "    # training loop can learn an easily separable problem.\n",
        "    toy_data = {}\n",
        "    nb_features = 50\n",
        "    for name, nb_samples in ((\"train\", 1000), (\"test\", 2)):\n",
        "        x1 = torch.randn(nb_samples, nb_features)\n",
        "        x2 = torch.randn(nb_samples, nb_features) + 1.5\n",
        "        x = torch.cat([x1, x2], dim=0)\n",
        "        y1 = torch.zeros(nb_samples, 1)\n",
        "        y2 = torch.ones(nb_samples, 1)\n",
        "        y = torch.cat([y1, y2], dim=0)\n",
        "        toy_data[name] = (x, y)\n",
        "        if name == \"train\":\n",
        "            plt.figure()\n",
        "            plt.scatter(x1[:, 0], x1[:, 1], color=\"b\")\n",
        "            plt.scatter(x2[:, 0], x2[:, 1], color=\"r\")\n",
        "    manager = DataManager.from_numpy(\n",
        "        train_inputs=toy_data[\"train\"][0], train_labels=toy_data[\"train\"][1],\n",
        "        batch_size=50, test_inputs=toy_data[\"test\"][0],\n",
        "        test_labels=toy_data[\"test\"][1])\n",
        "else:\n",
        "    # Only touch the real dataset when it is actually used.\n",
        "    data = fetch_impac(\n",
        "        datasetdir=DATASETDIR,\n",
        "        mode=\"train\",\n",
        "        dtype=dtype)\n",
        "    nb_features = data.nb_features\n",
        "    manager = DataManager(\n",
        "        input_path=data.input_path,\n",
        "        labels=[\"participants_asd\"],\n",
        "        metadata_path=data.metadata_path,\n",
        "        number_of_folds=10,\n",
        "        batch_size=128,\n",
        "        sampler=\"random\",\n",
        "        test_size=0,\n",
        "        sample_size=1)\n",
        "\n",
        "\n",
        "class DenseFeedForwardNet(nn.Module):\n",
        "    \"\"\" Dense feedforward network returning one logit per sample.\n",
        "    \"\"\"\n",
        "    def __init__(self, nb_features, use_alt=False):\n",
        "        \"\"\" Initialize the instance.\n",
        "\n",
        "        Parameters\n",
        "        ----------\n",
        "        nb_features: int\n",
        "            the size of the feature vector.\n",
        "        use_alt: bool, default False\n",
        "            if set, build the alternative architecture (50 then 100 hidden\n",
        "            units, ReLU/PReLU activations, dropout 0.2) instead of the\n",
        "            default one (64 then 64 hidden units, LeakyReLU activations,\n",
        "            dropout 0.13).\n",
        "        \"\"\"\n",
        "        super(DenseFeedForwardNet, self).__init__()\n",
        "        if use_alt:\n",
        "            layers = [\n",
        "                (\"linear1\", nn.Linear(nb_features, 50)),\n",
        "                (\"relu1\", nn.ReLU()),\n",
        "                (\"dropout\", nn.Dropout(0.2)),\n",
        "                (\"linear2\", nn.Linear(50, 100)),\n",
        "                (\"relu2\", nn.PReLU(1)),\n",
        "                (\"linear3\", nn.Linear(100, 1))]\n",
        "        else:\n",
        "            layers = [\n",
        "                (\"linear1\", nn.Linear(nb_features, 64)),\n",
        "                (\"relu1\", nn.LeakyReLU(negative_slope=0.01)),\n",
        "                (\"dropout\", nn.Dropout(0.13)),\n",
        "                (\"linear2\", nn.Linear(64, 64)),\n",
        "                (\"relu2\", nn.LeakyReLU(negative_slope=0.01)),\n",
        "                (\"linear3\", nn.Linear(64, 1))]\n",
        "        # Only the selected architecture is registered as a sub-module: the\n",
        "        # unused one would otherwise add never-trained parameters to the\n",
        "        # model, its printout and its state dict.\n",
        "        self.layers = nn.Sequential(collections.OrderedDict(layers))\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\" Return the raw logits (no sigmoid is applied here).\n",
        "        \"\"\"\n",
        "        return self.layers(x)\n",
        "\n",
        "\n",
        "def my_loss(x, y):\n",
        "    \"\"\" Binary cross-entropy loss computed on raw logits.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    x: Tensor\n",
        "        the network logits, one value per sample.\n",
        "    y: Tensor\n",
        "        the binary labels, with as many elements as x.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    loss: Tensor\n",
        "        the mean binary cross-entropy.\n",
        "    \"\"\"\n",
        "    logger.debug(\"Binary cross-entropy loss...\")\n",
        "    criterion = nn.BCEWithLogitsLoss()\n",
        "    x = x.view(-1, 1)\n",
        "    # Labels may be integers and may live on another device than the\n",
        "    # logits: cast them and move them next to the predictions.\n",
        "    y = y.view(-1, 1).to(device=x.device, dtype=torch.float32)\n",
        "    logger.debug(\"  x: {0} - {1}\".format(x.shape, x.dtype))\n",
        "    logger.debug(\"  y: {0} - {1}\".format(y.shape, y.dtype))\n",
        "    return criterion(x, y)\n",
        "\n",
        "\n",
        "def plot_metric_rank_correlations(metrics):\n",
        "    \"\"\" Display rank correlations for all numerical metrics calculated over N\n",
        "    experiments.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    metrics: DataFrame\n",
        "        a data frame with all computed metrics as columns and N rows.\n",
        "    \"\"\"\n",
        "    fig, ax = plt.subplots()\n",
        "    labels = metrics.columns\n",
        "    sns.heatmap(spearmanr(metrics)[0], annot=True, cmap=plt.get_cmap(\"Blues\"),\n",
        "                xticklabels=labels, yticklabels=labels, ax=ax)\n",
        "\n",
        "\n",
        "model = DenseFeedForwardNet(nb_features)\n",
        "print(model)\n",
        "cl = DeepLearningInterface(\n",
        "    optimizer_name=\"Adam\",\n",
        "    learning_rate=1e-4,\n",
        "    weight_decay=1.1e-4,\n",
        "    metrics=[\"binary_accuracy\", \"sk_roc_auc_score\"],\n",
        "    loss=my_loss,\n",
        "    model=model)\n",
        "cl.board = Board(port=8097, host=\"http://localhost\", env=\"main\")\n",
        "cl.add_observer(\"after_epoch\", update_board)\n",
        "scheduler = lr_scheduler.ReduceLROnPlateau(\n",
        "    optimizer=cl.optimizer,\n",
        "    mode=\"min\",\n",
        "    factor=0.1,\n",
        "    patience=5,\n",
        "    verbose=True,\n",
        "    eps=1e-8)\n",
        "test_history, train_history = cl.training(\n",
        "    manager=manager,\n",
        "    nb_epochs=50,\n",
        "    checkpointdir=None,\n",
        "    fold_index=0,\n",
        "    scheduler=scheduler,\n",
        "    with_validation=(not use_toy))\n",
        "\n",
        "if not use_toy:\n",
        "    data = fetch_impac(\n",
        "        datasetdir=DATASETDIR,\n",
        "        mode=\"test\",\n",
        "        dtype=dtype)\n",
        "    manager = DataManager(\n",
        "        input_path=data.input_path,\n",
        "        labels=[\"participants_asd\"],\n",
        "        metadata_path=data.metadata_path,\n",
        "        number_of_folds=3,\n",
        "        batch_size=10,\n",
        "        sampler=None,\n",
        "        test_size=1,\n",
        "        sample_size=1)\n",
        "y_pred, X, y_true, loss, values = cl.testing(\n",
        "    manager=manager,\n",
        "    with_logit=True,\n",
        "    logit_function=\"sigmoid\",\n",
        "    predict=False)\n",
        "# y_pred is treated as sigmoid probabilities from here on. Flatten the\n",
        "# predictions and the ground truth once, and derive the labels once, so\n",
        "# that every report below describes exactly the same predictions.\n",
        "y_prob = np.ravel(y_pred)\n",
        "y_truth = np.ravel(y_true)\n",
        "y_label = (y_prob >= THRESHOLD).astype(int)\n",
        "result = pd.DataFrame.from_dict(collections.OrderedDict([\n",
        "    (\"pred\", y_label),\n",
        "    (\"truth\", y_truth),\n",
        "    (\"prob\", y_prob)]))\n",
        "print(result)\n",
        "fig, ax = plt.subplots()\n",
        "cmap = plt.get_cmap('Blues')\n",
        "cm = SKMetrics(\"confusion_matrix\", with_logit=False)(y_pred, y_true)\n",
        "sns.heatmap(cm, cmap=cmap, annot=True, fmt=\"g\", ax=ax)\n",
        "ax.set_xlabel(\"predicted values\")\n",
        "ax.set_ylabel(\"actual values\")\n",
        "scores = {}\n",
        "sk_metrics = dict(\n",
        "    (key, val) for key, val in pynet.get_tools()[\"metrics\"].items()\n",
        "    if key.startswith(\"sk_\"))\n",
        "for name, metric in sk_metrics.items():\n",
        "    metric.with_logit = False\n",
        "    value = metric(y_pred, y_true)\n",
        "    scores.setdefault(name, []).append(value)\n",
        "metrics = pd.DataFrame.from_dict(scores)\n",
        "print(classification_report(y_truth, y_label))\n",
        "print(metrics)\n",
        "# Rank correlations need metrics from several experiments (one row per\n",
        "# run); with the single row computed above they are undefined.\n",
        "# plot_metric_rank_correlations(metrics)\n",
        "fpr, tpr, _ = roc_curve(y_truth, y_prob)\n",
        "roc_auc = auc(fpr, tpr)\n",
        "plt.figure()\n",
        "plt.plot(fpr, tpr, color=\"darkorange\", lw=2,\n",
        "         label=\"ROC curve (area = %0.2f)\" % roc_auc)\n",
        "plt.plot([0, 1], [0, 1], color=\"navy\", lw=2, linestyle=\"--\")\n",
        "plt.xlim([0.0, 1.0])\n",
        "plt.ylim([0.0, 1.05])\n",
        "plt.xlabel(\"False Positive Rate\")\n",
        "plt.ylabel(\"True Positive Rate\")\n",
        "plt.title(\"Receiver operating characteristic\")\n",
        "plt.legend(loc=\"lower right\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}