{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet optim helpers overview\n==============================\n\nCredit: A Grigis\n\npynet is a Python package related to deep learning and its application in\nMRI mediacal data analysis. It is accessible to everybody, and is reusable\nin various contexts. The project is hosted on github:\nhttps://github.com/neurospin/pynet.\n\nFirst checks\n------------\n\nIn order to test if the 'pynet' package is installed on your machine, you can\ncheck the package version.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pynet\nfrom pynet.utils import setup_logging\nsetup_logging(level=\"info\")\nprint(pynet.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now you can run the the configuration info function to see if all the\ndependencies are installed properly.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pynet.configure\nprint(pynet.configure.info())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optimisation\n------------\n\nFirst load a dataset (the CIFAR10) and a network.\nYou may need to change the 'datasetdir' parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport torch.nn as nn\nimport torch.nn.functional as func\nfrom pynet.datasets import DataManager, fetch_cifar\n\ndata = fetch_cifar(datasetdir=\"/tmp/cifar\")\nmanager = DataManager(\n    input_path=data.input_path,\n    labels=[\"label\"],\n    metadata_path=data.metadata_path,\n    number_of_folds=10,\n    batch_size=10,\n    stratify_label=\"category\",\n    test_size=0.1,\n    sample_size=(1 if \"CI_MODE\" not in os.environ else 0.01))\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(func.relu(self.conv1(x)))\n        x = self.pool(func.relu(self.conv2(x)))\n        x = x.view(-1, 16 * 5 * 5)\n        x = func.relu(self.fc1(x))\n        x = func.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nnet = Net()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now start the optimisation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom pynet.interfaces import DeepLearningInterface\n\ncl = DeepLearningInterface(\n    model=net,\n    optimizer_name=\"SGD\",\n    momentum=0.9,\n    learning_rate=0.001,\n    loss_name=\"CrossEntropyLoss\",\n    metrics=[\"accuracy\"])\nif \"CI_MODE\" not in os.environ:\n    from pynet.plotting import Board\n\n    def update_board(signal):\n        \"\"\" Callback to update visdom board visualizer.\n\n        Parameters\n        ----------\n        signal: SignalObject\n            an object with the trained model 'object', the emitted signal\n            'signal', the epoch number 'epoch' and the fold index 'fold'.\n        \"\"\"\n        net = signal.object.model\n        emitted_signal = signal.signal\n        epoch = signal.epoch\n        fold = signal.fold\n        data = {}\n        for key in signal.keys:\n            if key in (\"epoch\", \"fold\"):\n                continue\n            data[key] = getattr(signal, key)\n        board.update_plots(data)\n    board = Board(port=8097, host=\"http://localhost\", env=\"main\")\n    cl.add_observer(\"after_epoch\", update_board)\ntest_history, train_history = cl.training(\n    manager=manager,\n    nb_epochs=3,\n    checkpointdir=\"/tmp/pynet\",\n    fold_index=0,\n    with_validation=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can reload the optimization history at any time and any step.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pprint import pprint\nfrom pynet.history import History\nfrom pynet.plotting import plot_history\n\nhistory = History.load(\"/tmp/pynet/train_0_epoch_2.pkl\")\nprint(history)\nplot_history(history)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "And now predict the labels on the test set.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom sklearn.metrics import classification_report\nfrom pynet.plotting import plot_data\n\ny_pred, X, y_true, loss, values = cl.testing(\n    manager=manager,\n    with_logit=True,\n    predict=True)\npprint(data.labels)\nprint(classification_report(y_true, y_pred, target_names=data.labels.values()))\ntitles = [\"{0}-{1}\".format(data.labels[it1], data.labels[it2])\n          for it1, it2 in zip(y_pred, y_true)]\nplot_data(X, labels=titles, nb_samples=5)\n\nif \"CI_MODE\" not in os.environ:\n    import matplotlib.pyplot as plt\n    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}