{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render any matplotlib figures inline (the augmentation results below are\n# displayed through the pynet Board server instead).\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet data augmentation overview\n================================\n\nCredit: A Grigis\n\npynet contains a set of tools to efficiently augment 3D medical images, which\nis crucial for deep learning applications. It includes random affine/non-linear\ntransformations, simulation of intensity artifacts due to MRI magnetic field\ninhomogeneity or k-space motion artifacts, and others.\n\nLoad the data\n-------------\n\nWe load a toy T1-weighted MRI brain image, downsample it (scale=4), and\nrescale its intensities to the [0, 255] range.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\n# Skip this demo when running in CI mode.\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nimport time\nimport random\nimport numpy as np\nimport nibabel\nfrom pynet.datasets import DataManager, fetch_toy, fetch_brats\nfrom pynet.preprocessing import rescale, downsample\n\n# Seed Python's and NumPy's global random generators so that the random\n# augmentations below are reproducible across runs (assumes the pynet\n# augmentations draw from these generators - TODO confirm).\nSEED = 42\nrandom.seed(SEED)\nnp.random.seed(SEED)\n\ndatasetdir = \"/tmp/toy\"\ndata = fetch_toy(datasetdir=datasetdir)\nt1w_image = nibabel.load(data.t1w_path)\n# get_data() is deprecated and raises an error since nibabel 5.0; get_fdata()\n# returns the voxel array as floating point.\nimage = rescale(downsample(t1w_image.get_fdata(), scale=4),\n                dynamic=(0, 255))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
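        "Define deformations\n-------------------\n\nWe now declare the MRI brain augmentation functions (spatial deformations and\nintensity artifacts) together with their parameters. Several of them can be\ncombined with the Transformer class, as illustrated by `compose_transforms`.\n\n"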
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pynet.augmentation import (\n    Transformer, add_biasfield, add_blur, add_ghosting, add_motion, add_noise,\n    add_offset, add_spike, affine, deformation, flip)\n\n# Composed augmentation (with_channel=False, presumably because the loaded\n# image has no channel dimension): a flip along axis 0 with probability 0.5,\n# followed by a blur that is always applied.\ncompose_transforms = Transformer(with_channel=False)\ncompose_transforms.register(flip, apply_to=[\"all\"], probability=0.5, axis=0)\ncompose_transforms.register(add_blur, apply_to=[\"all\"], probability=1, sigma=4)\n\n# Augmentation name -> (function, keyword arguments); each entry is applied to\n# the loaded image in the next section. The composed transform already holds\n# its own parameters, hence the empty keyword arguments.\ntransforms = {\n    \"add_blur\": (add_blur, {\"sigma\": 4}),\n    \"add_noise\": (add_noise, {\"snr\": 5., \"noise_type\": \"rician\"}),\n    \"flip\": (flip, {\"axis\": 0}),\n    \"affine\": (affine, {\"rotation\": 5, \"translation\": 0, \"zoom\": 0.05}),\n    \"add_ghosting\": (add_ghosting,\n                     {\"n_ghosts\": (4, 10), \"axis\": 2, \"intensity\": (0.5, 1)}),\n    \"add_spike\": (add_spike, {\"n_spikes\": 1, \"intensity\": (0.1, 1)}),\n    \"add_biasfield\": (add_biasfield, {\"coefficients\": 0.5}),\n    \"deformation\": (deformation, {\"max_displacement\": 4, \"alpha\": 3}),\n    \"add_motion\": (add_motion,\n                   {\"rotation\": 10, \"translation\": 10, \"n_transforms\": 2,\n                    \"perturbation\": 0.3}),\n    \"add_offset\": (add_offset, {\"factor\": (0.05, 0.1)}),\n    \"compose_transforms\": (compose_transforms, {}),\n}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Test transformations\n--------------------\n\nWe now apply each transformation to the loaded image, ten times. For each\ntransformation, the middle slices of the original and transformed images are\ndisplayed side by side in your browser at http://localhost:8097.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pynet.plotting import Board\n\n# Draw each augmentation n_iterations times; every draw displays the middle\n# slice (along the last axis) of the original and of the augmented image.\nn_iterations = 10\nboard = Board(port=8097, host=\"http://localhost\", env=\"data-augmentation\")\nfor iteration in range(n_iterations):\n    print(\"Iteration: \", iteration)\n    for name, (augment, params) in transforms.items():\n        # Clip back to the [0, 255] dynamic the image was rescaled to.\n        augmented = np.clip(augment(image, **params), 0, 255)\n        stacked = np.asarray([image, augmented])\n        middle_slices = stacked[..., stacked.shape[-1] // 2]\n        # Insert a singleton channel axis after the image axis before display.\n        board.viewer.images(\n            np.expand_dims(middle_slices, axis=1),\n            opts={\"title\": name, \"caption\": name}, win=name)\n    time.sleep(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Data augmentation\n-----------------\n\nWe now illustrate how the Transformer can be used in combination with the\nDataManager to perform data augmentation during training. Results are\ndirectly displayed in your browser at http://localhost:8097.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# The BraTS location is site specific: set the BRATS_DATASETDIR environment\n# variable to point to your own copy (defaults to the original path).\ndatasetdir = os.environ.get(\n    \"BRATS_DATASETDIR\", \"/neurospin/nsap/processed/deepbrain/tumor/data/brats\")\ndata = fetch_brats(datasetdir=datasetdir)\n\nboard = Board(port=8097, host=\"http://localhost\", env=\"data-augmentation\")\n# Flip inputs and outputs with probability 0.5; blur the inputs only.\ncompose_transforms = Transformer()\ncompose_transforms.register(\n    flip, probability=0.5, axis=0, apply_to=[\"input\", \"output\"])\ncompose_transforms.register(\n    add_blur, probability=1, sigma=4, apply_to=[\"input\"])\nmanager = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    output_path=data.output_path,\n    number_of_folds=2,\n    batch_size=2,\n    test_size=0.1,\n    sample_size=0.1,\n    sampler=None,\n    add_input=True,\n    data_augmentation_transforms=[compose_transforms])\nloaders = manager.get_dataloader(\n    train=True,\n    validation=False,\n    fold_index=0)\nfor dataitem in loaders.train:\n    print(\"-\" * 50)\n    print(dataitem.inputs.shape, dataitem.outputs.shape, dataitem.labels)\n    # Selected channels of the first sample in the batch (outputs channels 4\n    # and 5 presumably come from add_input=True - TODO confirm).\n    images = [dataitem.inputs[0, 0].numpy(), dataitem.inputs[0, 1].numpy(),\n              dataitem.outputs[0, 0].numpy(), dataitem.outputs[0, 1].numpy(),\n              dataitem.outputs[0, 4].numpy(), dataitem.outputs[0, 5].numpy()]\n    images = np.asarray(images)\n    images = np.expand_dims(images, axis=1)\n    # Keep the middle slice along the last axis, then rescale for display.\n    images = images[..., images.shape[-1] // 2]\n    images = rescale(images, dynamic=(0, 255))\n    board.viewer.images(\n        images, opts={\"title\": \"transformer\", \"caption\": \"transformer\"},\n        win=\"transformer\")\n    time.sleep(2)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}