{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline in the notebook.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet echocardiography segmentation\n===================================\n\nCredit: A Grigis\n\npynet is a Python package related to deep learning and its application in\nMRI medical data analysis. It is accessible to everybody, and is reusable\nin various contexts. The project is hosted on github:\nhttps://github.com/neurospin/pynet.\n\nIn this hands-on session, we will use a deep segmentation network (PSPNet or\nDeepLab) to segment 2D echocardiography images. In particular, we will focus\non the segmentation of three adjacent cardiac structures: the left ventricle,\nthe myocardium and the right ventricle. The segmentation of these ultrasound\nimages is particularly difficult due to the many sources of artifacts, the\ndifficulty of recognizing the structures to segment, and the subjective\ndelineation of the contours (e.g. at the lower part of the myocardial\nsegmentation mask).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\n\n# Skip the rest of this example when the CI_MODE environment variable is\n# set (presumably on continuous integration).\nif os.environ.get(\"CI_MODE\") is not None:\n    sys.exit()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import the dataset\n------------------\n\nYou may need to change the 'datasetdir' parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport numpy as np\nfrom pynet.datasets import DataManager, fetch_echocardiography\nfrom pynet.plotting import plot_data\nfrom pynet.utils import setup_logging\n\nsetup_logging(level=\"info\")\n\n# Fetch the echocardiography dataset into 'datasetdir'. The handle is kept\n# distinct from the array built below (different kinds of objects).\necho_data = fetch_echocardiography(\n    datasetdir=\"/tmp/echocardiography\")\n\n# NOTE(review): judging by the parameter names, this builds 2 folds plus a\n# 10% test split from a random 20% subsample, in batches of 10 - confirm\n# against the pynet DataManager documentation.\nmanager = DataManager(\n    input_path=echo_data.input_path,\n    metadata_path=echo_data.metadata_path,\n    output_path=echo_data.output_path,\n    number_of_folds=2,\n    stratify_label=\"label\",\n    sampler=\"random\",\n    batch_size=10,\n    test_size=0.1,\n    sample_size=0.2)\ndataset = manager[\"test\"]\nprint(dataset.inputs.shape, dataset.outputs.shape)\n\n# Stack the inputs and the target outputs along axis 1 to display them\n# together.\ntest_samples = np.concatenate((dataset.inputs, dataset.outputs), axis=1)\nplot_data(test_samples, nb_samples=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optimisation\n------------\n\nFrom the available models, load a segmentation network (PSPNet by default,\nor DeepLab) and start the training.\nYou may need to change the 'outdir' parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn as nn\nfrom pynet import NetParameters\nfrom pynet.interfaces import DeepLabNetSegmenter, PSPNetSegmenter\nfrom pynet.plotting import plot_history\nfrom pynet.history import History\n\n\ndef my_loss(x, y):\n    \"\"\" Cross-entropy loss between the network logits and channel-wise targets.\n\n    nn.CrossEntropyLoss expects a torch.LongTensor containing the class\n    indices without the channel dimension: the targets are thus collapsed\n    with an argmax over the channel axis (dim=1).\n\n    Parameters\n    ----------\n    x: torch.Tensor\n        the network outputs (logits), classes on dim 1.\n    y: torch.Tensor\n        the targets, one channel per class on dim 1 (presumably one-hot\n        encoded masks - TODO confirm).\n\n    Returns\n    -------\n    loss: torch.Tensor\n        the averaged cross-entropy (nn.CrossEntropyLoss default\n        reduction='mean').\n    \"\"\"\n    # .long() keeps the labels on the same device as y, avoiding the CPU\n    # round trip of .type(torch.LongTensor) followed by .to(device).\n    y = torch.argmax(y, dim=1).long()\n    criterion = nn.CrossEntropyLoss()\n    return criterion(x, y)\n\n\noutdir = \"/tmp/echocardiography\"\n# NOTE(review): outdir is not used below (checkpointdir=None); pass\n# checkpointdir=outdir to net.training to (presumably) save checkpoints.\n\n# Choose the segmentation network: \"pspnet\" or \"deeplab\".\nmodel = \"pspnet\"\n# Settings shared by both segmenters.\nsegmenter_kwargs = dict(\n    optimizer_name=\"Adam\",\n    learning_rate=5e-4,\n    metrics=[\"multiclass_dice\"],\n    loss=my_loss,\n    use_cuda=False)\n# n_classes=4: presumably the background plus the three cardiac structures.\nif model == \"pspnet\":\n    params = NetParameters(\n        n_classes=4,\n        sizes=(1, 2, 3, 6),\n        psp_size=512,\n        deep_features_size=256,\n        backend=\"resnet18\",\n        drop_rate=0)\n    net = PSPNetSegmenter(params, **segmenter_kwargs)\nelif model == \"deeplab\":\n    params = NetParameters(\n        n_classes=4,\n        drop_rate=0)\n    net = DeepLabNetSegmenter(params, **segmenter_kwargs)\nelse:\n    raise ValueError(\"Unknown model: {0}.\".format(model))\nprint(net.model)\ntrain_history, valid_history = net.training(\n    manager=manager,\n    nb_epochs=10,\n    checkpointdir=None,\n    fold_index=0,\n    with_validation=True)\nprint(train_history)\nprint(valid_history)\nplot_history(train_history)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Testing\n-------\n\nFinally, evaluate the trained network on the testing set and check the results.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Run the trained network on the test set; with predict=True, y_pred\n# presumably holds the predicted class labels (TODO confirm).\ny_pred, X, y_true, loss, values = net.testing(\n    manager=manager, with_logit=True, predict=True)\nprint(y_pred.shape, X.shape, y_true.shape)\n\n# Add a channel axis (axis 1) to the predictions, then stack them with the\n# ground truth and the inputs along that axis for display.\ny_pred = np.expand_dims(y_pred, 1)\ndata = np.concatenate([y_pred, y_true, X], axis=1)\nplot_data(data, nb_samples=5, random=False)\n\n# Show the figures unless running in CI mode.\nif os.environ.get(\"CI_MODE\") is None:\n    import matplotlib.pyplot as plt\n    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}