{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\npynet: class activation map\n===========================\n\nCredit: A Grigis\n\nBased on:\n\n- https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82\n- http://snappishproductions.com/blog/2018/01/03/class-activation-mapping-in-pytorch.html\n- https://github.com/jacobgil/pytorch-grad-cam\n\nA class activation map for a particular category highlights the discriminative image regions the CNN uses to identify that category. It shows which parts of the image drove the model's decision for a given label, which makes it particularly useful for analyzing misclassified samples.\n\nGrad-CAM, the method used here, starts by computing the gradient of the largest (most probable) logit with respect to the activation map of the last convolutional layer in the model. This gradient measures how strongly the features encoded in that final activation map pushed the model towards that logit, and consequently towards the corresponding class. The gradients are then pooled channel-wise (averaged over the spatial dimensions), and each activation channel is weighted by its pooled gradient, yielding a collection of weighted activation channels. By inspecting these channels, we can tell which ones played the most significant role in the decision for that class.\n\nThe main idea is to dissect the network as follows:\n\n- load the model\n- find its last convolutional layer\n- compute the most probable class\n- take the gradient of the class logit with respect to the activation map of that layer\n- pool the gradients\n- weight the channels of the map by the corresponding pooled gradients\n- interpolate the heatmap to the size of the input image\n\nPyTorch computes gradients when the 'backward()' method is called on a torch.Tensor. This is exactly what we are going to do: call 'backward()' on the most probable logit, which we obtain by passing the image forward through the network. However, PyTorch only caches the gradients of the leaf nodes in the computational graph, such as weights, biases and other parameters. The gradients of the output with respect to the activations are merely intermediate values and are discarded as soon as the gradient has propagated through them on the way back. We therefore have to register a backward hook on the activation map of the last convolutional layer in our model to capture them.\n\nLoad the data\n-------------\n\nLoad some images and apply the ImageNet transformation.\nYou may need to change the 'datasetdir' parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pynet.datasets import DataManager, fetch_gradcam\nfrom pynet.plotting import plot_data\n\ndata = fetch_gradcam(\n    datasetdir=\"/tmp/gradcam\")\nmanager = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    number_of_folds=2,\n    batch_size=5,\n    test_size=1)\ndataset = manager[\"test\"]\nprint(dataset.inputs.shape)\nplot_data(dataset.inputs, nb_samples=5, random=False, rgb=True)"
      ]
    },
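    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Sketch the pooling and weighting steps\n--------------------------------------\n\nThe pooling, weighting and interpolation steps listed in the introduction can be sketched with NumPy alone. This is a minimal, hypothetical illustration, not the 'pynet.cam.GradCam' implementation: 'grad_cam_map' and 'upsample_nearest' are helpers made up for this example, random arrays stand in for the activations and gradients that a backward hook would capture, the gradients are pooled by averaging over the spatial dimensions as in the original Grad-CAM formulation, a ReLU keeps only the positive contributions, and nearest-neighbour upsampling replaces a proper (e.g. bilinear) interpolation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\n\n# Hypothetical helper: Grad-CAM heatmap from the activations of one layer and\n# the gradients of the class logit with respect to them, both of shape\n# (channels, height, width) with the batch dimension already removed.\ndef grad_cam_map(activations, gradients):\n    # Pool the gradients channel-wise: average over height and width to get\n    # one weight per channel.\n    weights = gradients.mean(axis=(1, 2))\n    # Weight each activation channel by its pooled gradient and sum them.\n    cam = np.tensordot(weights, activations, axes=1)\n    # Keep only the regions with a positive influence on the class.\n    cam = np.maximum(cam, 0)\n    # Scale to [0, 1] for display, leaving an all-zero map unchanged.\n    if cam.max() > 0:\n        cam = cam / cam.max()\n    return cam\n\n\n# Hypothetical helper: nearest-neighbour upsampling by an integer factor, a\n# simple stand-in for the interpolation of the heatmap.\ndef upsample_nearest(cam, factor):\n    return np.repeat(np.repeat(cam, factor, axis=0), factor, axis=1)\n\n\n# Random arrays standing in for a 512-channel 7x7 activation map and its\n# gradients (as a backward hook would capture them).\nrng = np.random.default_rng(0)\nactivations = rng.random((512, 7, 7))\ngradients = rng.standard_normal((512, 7, 7))\ncam = grad_cam_map(activations, gradients)\nheatmap = upsample_nearest(cam, 32)\nprint(cam.shape, heatmap.shape)  # -> (7, 7) (224, 224)"
      ]
    },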
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Explore different architectures\n-------------------------------\n\nLet's repeat this procedure for several networks.\nInception v3 expects 299x299 inputs, so the data are fetched a second time with 'inception=True' for that network (its entry is commented out in the loop below).\nYou may need to change the 'datasetdir' parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\nimport matplotlib.pyplot as plt\nfrom pynet.models.cam import get_cam_network\nfrom pynet.cam import GradCam\n\n# Data for the VGG, DenseNet and ResNet networks\ndata = fetch_gradcam(\n    datasetdir=\"/tmp/gradcam\")\nmanager1 = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    number_of_folds=2,\n    batch_size=1,\n    test_size=1)\nloaders1 = manager1.get_dataloader(test=True)\n\n# Data for the Inception v3 network\ndata = fetch_gradcam(\n    datasetdir=\"/tmp/gradcam\",\n    inception=True)\nmanager2 = DataManager(\n    input_path=data.input_path,\n    metadata_path=data.metadata_path,\n    number_of_folds=2,\n    batch_size=1,\n    test_size=1)\nloaders2 = manager2.get_dataloader(test=True)\n\nnetworks = [\n    (loaders1, \"vgg19\"),\n    (loaders1, \"densenet201\"),\n    (loaders1, \"resnet18\"),\n    # (loaders2, \"inception_v3\"),\n]\nfor loaders, model_name in networks:\n    print(\"-\" * 10)\n    print(model_name)\n\n    # Load the network and the name of the layer to inspect once per\n    # architecture, then collect the top-1 heatmap of every test image\n    model, activation_layer_name = get_cam_network(model_name)\n    grad_cam = GradCam(model, [activation_layer_name], data.labels, top=1)\n    heatmaps = []\n    for dataitem in loaders.test:\n        heatmaps.extend(grad_cam(dataitem.inputs).items())\n\n    # Top row: heatmaps; bottom row: high-resolution heatmaps overlaid on the\n    # input images. squeeze=False keeps 'axs' 2D even for a single heatmap.\n    fig, axs = plt.subplots(nrows=2, ncols=len(heatmaps), squeeze=False)\n    fig.suptitle(model_name, fontsize=\"large\")\n    for cnt, (name, (img, arr, arr_highres)) in enumerate(heatmaps):\n        axs[0, cnt].set_title(name)\n        axs[0, cnt].matshow(arr)\n        axs[0, cnt].set_axis_off()\n        # (batch, channels, height, width) -> (height, width, channels)\n        _img = img.data.numpy()[0].transpose((1, 2, 0))\n        axs[1, cnt].imshow(_img)\n        axs[1, cnt].imshow(arr_highres, alpha=0.6, cmap=\"jet\")\n        axs[1, cnt].set_axis_off()\n\nif \"CI_MODE\" not in os.environ:\n    plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}