{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline, below the cell that creates them.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet: transfer learning\n========================\n\nCredit: A Grigis\nBased on:\n- https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n\nIn this tutorial, you will learn how to train your network using transfer\nlearning on the multi-modal orientation prediction dataset. Why? Because in\nmany cases we do not have enough data to train a network from scratch.\n\nWe will use a network pretrained on the ImageNet dataset and freeze the\nweights of part or all of the network (except those of the final fully\nconnected layer, which is replaced and retrained).\n\nRead the data\n-------------\n\nEach flattened grayscale image is reshaped and replicated over three\nchannels so that it matches the RGB input of the pretrained network.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport sys\n\nif \"CI_MODE\" in os.environ:\n    sys.exit()\n\nfrom pynet.datasets import fetch_orientation\nfrom pynet.datasets import DataManager\nfrom skimage.color import gray2rgb\n\n\ndata = fetch_orientation(\n    datasetdir=\"/tmp/orientation\",\n    flatten=True)\n\n\ndef prepare(arr):\n    \"\"\" Convert a flattened grayscale image into a channel-first RGB array.\n\n    The gray channel is replicated three times with `gray2rgb` so that the\n    image matches the 3-channel input of the pretrained network.\n\n    Parameters\n    ----------\n    arr: array\n        a flattened image holding data.height * data.width values.\n\n    Returns\n    -------\n    rgb: array (3, data.height, data.width)\n        the RGB image with the channel axis first.\n    \"\"\"\n    rgb = gray2rgb(arr.reshape((data.height, data.width)))\n    # gray2rgb appends the channel axis last: move it first (C, H, W).\n    return rgb.transpose(2, 0, 1)\n\n\nmanager = DataManager(\n    input_path=data.input_path,\n    labels=[\"label\"],\n    metadata_path=data.metadata_path,\n    number_of_folds=10,\n    batch_size=1000,\n    stratify_label=\"label\",\n    test_size=0.1,\n    # Presumably the fraction of samples used: a 10% subsample keeps the\n    # tutorial fast.\n    sample_size=0.1,\n    input_transforms=[prepare])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Display some images from the test dataset.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nfrom pynet.plotting import plot_data\n\n# Display a few raw test images: the flattened inputs are reshaped into a\n# (nb_images, 1, height, width) batch, i.e. with a singleton channel axis.\ndataset = manager[\"test\"]\nsample = dataset.inputs.reshape(-1, 1, data.height, data.width)\nplot_data(sample, nb_samples=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
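        "Load the model\n--------------\n\nLoad a ResNet-18 pretrained on ImageNet, freeze the weights of all its\nlayers except the final fully connected one, and replace that last linear\nlayer with a new 9-class layer.\n\n"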
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import pynet.interfaces as interfaces\nfrom pynet import NetParameters\nfrom pynet.utils import get_named_layers, freeze_layers, reset_weights\nimport torch.nn as nn\n\n# num_classes=1000 matches the ImageNet head of the pretrained checkpoint\n# (presumably required for the weights to load); the head is swapped below.\nnet_params = NetParameters(num_classes=1000)\ncl = interfaces.ResNet18Classifier(\n    net_params,\n    pretrained=\"/neurospin/nsap/torch/models/resnet18-5c106cde.pth\",\n    optimizer_name=\"Adam\",\n    learning_rate=1e-4,\n    # NOTE(review): NLLLoss expects log-probabilities; confirm the network\n    # output goes through a log_softmax, otherwise prefer CrossEntropyLoss.\n    loss_name=\"NLLLoss\",\n    metrics=[\"accuracy\"])\nprint(cl.model)\n\n# List the named sub-modules to pick the ones to freeze.\nlayers = get_named_layers(cl.model, allowed_layers=[nn.Module], resume=True)\nprint(layers.keys())\n\n# Freeze the whole backbone (stem + the four residual stages): only the new\n# classification head will remain trainable.\nfrozen_layer_names = [\"conv1\", \"bn1\", \"relu\", \"maxpool\"]\nfrozen_layer_names += [\"layer{0}\".format(stage) for stage in range(1, 5)]\nfreeze_layers(cl.model, frozen_layer_names)\n\n# Replace the 1000-class head by a fresh 9-class linear layer.\n# NOTE(review): the optimizer is presumably built by the constructor above\n# (optimizer_name); confirm it also tracks the parameters of this new layer.\ncl.model.fc = nn.Linear(cl.model.fc.in_features, 9)\nprint(cl.model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
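        "Retrain the model\n-----------------\n\nTrain the model for a few epochs, print the weight tensors updated by the\ntraining, and display some test predictions (titled predicted-true) together\nwith the training history.\n\n"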
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pynet.plotting import plot_history\n\n\ndef _weights_snapshot(model):\n    \"\"\" Return a detached CPU copy of every '.weight' tensor of a model.\n\n    The tensors returned by `state_dict()` share their storage with the\n    live parameters, so they are copied: otherwise in-place optimizer\n    updates would silently modify the snapshot too.\n    \"\"\"\n    return dict(\n        (key, val.detach().cpu().numpy().copy())\n        for key, val in model.state_dict().items()\n        if key.endswith(\".weight\"))\n\n\ndef train(cl, dataset, nb_epochs=5):\n    \"\"\" Train a classifier, report the updated layers and show predictions.\n\n    Parameters\n    ----------\n    cl: Classifier\n        the pynet classifier to train (modified in place).\n    dataset: object\n        unused, kept for backward compatibility: the training and test\n        data come from the global `manager`.\n    nb_epochs: int, default 5\n        the number of training epochs.\n    \"\"\"\n    state = _weights_snapshot(cl.model)\n    test_history, train_history = cl.training(\n        manager=manager,\n        nb_epochs=nb_epochs,\n        checkpointdir=None,\n        fold_index=0,\n        with_validation=False)\n    train_state = _weights_snapshot(cl.model)\n    # Print the weight tensors changed by the training: frozen layers\n    # are expected not to appear here.\n    for key, val in state.items():\n        if not np.allclose(val, train_state[key]):\n            print(\"--\", key)\n\n    idx = 0\n    y_pred_prob, X, y_true, loss, values = cl.testing(\n        manager=manager,\n        with_logit=True,\n        predict=False)\n    y_pred = np.argmax(y_pred_prob, axis=1)\n    print(\" ** true label      : \", y_true[idx])\n    print(\" ** predicted label : \", y_pred[idx])\n    # Each image is titled '<predicted label>-<true label>'.\n    titles = [\"{0}-{1}\".format(data.labels[it1], data.labels[it2])\n              for it1, it2 in zip(y_pred, y_true)]\n    plot_data(X, labels=titles, nb_samples=5)\n    plot_history(train_history)\n\n\ntrain(cl, dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
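        "Test different strategies\n-------------------------\n\nFreezing the whole pretrained backbone does not give good results here, so\nlet's try two other transfer learning strategies: freezing only the first\nlayers (up to `layer2`), and resetting all the pretrained weights with\n`reset_weights` before training the whole network.\n\n"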
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "PRETRAINED_WEIGHTS = \"/neurospin/nsap/torch/models/resnet18-5c106cde.pth\"\n\n\ndef build_classifier(to_freeze_layers=None, reset=False, nb_classes=9):\n    \"\"\" Build a pretrained ResNet18 classifier with a new linear head.\n\n    Parameters\n    ----------\n    to_freeze_layers: list of str, default None\n        the names of the layers to freeze.\n    reset: bool, default False\n        if set, reset the pretrained weights with `reset_weights`.\n    nb_classes: int, default 9\n        the number of outputs of the new final linear layer.\n\n    Returns\n    -------\n    cl: Classifier\n        the classifier whose `fc` layer has been replaced.\n    \"\"\"\n    cl = interfaces.ResNet18Classifier(\n        net_params,\n        pretrained=PRETRAINED_WEIGHTS,\n        optimizer_name=\"Adam\",\n        learning_rate=1e-4,\n        loss_name=\"NLLLoss\",\n        metrics=[\"accuracy\"])\n    if to_freeze_layers:\n        freeze_layers(cl.model, to_freeze_layers)\n    if reset:\n        reset_weights(cl.model)\n    cl.model.fc = nn.Linear(cl.model.fc.in_features, nb_classes)\n    return cl\n\n\n# Strategy 1: freeze only the first layers and fine-tune layer3, layer4\n# and the new head.\ncl = build_classifier(\n    to_freeze_layers=[\"conv1\", \"bn1\", \"relu\", \"maxpool\", \"layer1\", \"layer2\"])\ntrain(cl, dataset)\n\n# Strategy 2: nothing is frozen and all the pretrained weights are reset\n# with reset_weights, so the whole network is trained.\ncl = build_classifier(reset=True)\ntrain(cl, dataset)\n\nif \"CI_MODE\" not in os.environ:\n    import matplotlib.pyplot as plt\n    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}