Helper Module for Deep Learning.
pynet: class activation map
Credit: A Grigis
Based on:
https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82
http://snappishproductions.com/blog/2018/01/03/class-activation-mapping-in-pytorch.html
A class activation map for a particular category indicates the discriminative image regions used by the CNN to identify that category. It gives us a way to see which parts of the image influenced the model’s decision for a given label. It is particularly useful for analyzing misclassified samples.
It starts by computing the gradient of the most dominant logit with respect to the last activation map in the model. We can interpret this as follows: the encoded features that ended up activated in the final activation map persuaded the model as a whole to choose that particular logit (and hence the corresponding class). The gradients are then pooled channel-wise, and each activation channel is weighted by its pooled gradient, yielding a collection of weighted activation channels. By inspecting these channels, we can tell which ones played the most significant role in the class decision.
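The pooling and weighting described above can be written compactly. The notation is an assumption introduced for illustration (it follows the usual Grad-CAM formulation and is not defined by pynet): $A^k$ is the $k$-th channel of the last activation map, of spatial size $H \times W$, and $y^c$ is the logit of class $c$.

```latex
\alpha_k^c = \frac{1}{H W} \sum_{i=1}^{H} \sum_{j=1}^{W}
             \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{ij} = \sum_k \alpha_k^c \, A^k_{ij}
```

Here $\alpha_k^c$ is the pooled gradient of channel $k$ and $L^c$ is the resulting low-resolution heat-map. The usual Grad-CAM formulation additionally applies a ReLU to $L^c$, so that only positive contributions to the class are kept.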
The main idea is to dissect the network as follows:
load the model
find its last convolutional layer
compute the most probable class
take the gradient of the class logit with respect to the activation maps of that layer
pool the gradients
weight the channels of the map by the corresponding pooled gradients
interpolate the heat-map
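The steps above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the pynet implementation: the toy network (`features`, `classifier`) and the random input are hypothetical stand-ins for a pretrained model and a real image, and the gradient is obtained with `torch.autograd.grad` rather than with a hook, which works here because the network is split at its last convolutional layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy CNN, split so that the output of the last
# convolutional block is directly accessible.
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
)
classifier = nn.Sequential(
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
features.eval()
classifier.eval()

image = torch.randn(1, 3, 32, 32)  # stand-in for a preprocessed image

# Forward pass: activation maps of the last conv block, then the logits.
activations = features(image)          # shape (1, 16, 16, 16)
logits = classifier(activations)       # shape (1, 10)
class_idx = int(logits.argmax(dim=1))  # most probable class

# Gradient of the class logit with respect to the activation maps.
grads, = torch.autograd.grad(logits[0, class_idx], activations)

# Pool the gradients channel-wise: one weight per channel.
weights = grads.mean(dim=(2, 3), keepdim=True)  # shape (1, 16, 1, 1)

# Weight the channels, sum them, and keep the positive part.
heatmap = F.relu((weights * activations).sum(dim=1, keepdim=True)).detach()

# Normalize to [0, 1], guarding against an all-zero map.
heatmap = heatmap / heatmap.max().clamp(min=1e-8)

# Interpolate the heat-map to the input resolution.
heatmap_highres = F.interpolate(
    heatmap, size=image.shape[2:], mode="bilinear", align_corners=False)
print(heatmap.shape, heatmap_highres.shape)
```

With a real architecture the last convolutional layer is usually buried inside the model, so the activations cannot be pulled out this easily and a hook is the more practical way to reach them.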
We can compute the gradients in PyTorch using the ‘backward’ method of a torch.Tensor. This is exactly what we are going to do: call ‘backward()’ on the most probable logit, which we obtain from a forward pass of the image through the network. However, PyTorch only retains the gradients of the leaf nodes of the computational graph, such as weights, biases and other parameters. The gradients of the output with respect to the activations are intermediate values and are discarded as soon as the gradient has propagated through them on the way back. We therefore have to register a backward hook on the activation map of the last convolutional layer in our model.
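A small sketch of this behaviour, assuming a hypothetical two-part toy model (`conv` and `head`) rather than one of the pynet networks: the activation tensor is not a leaf, so its gradient is not kept after ‘backward()’, but a hook registered with `torch.Tensor.register_hook` receives it while it passes through.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: `conv` plays the role of the last conv layer.
conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
head = nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, 3))

captured = {}

def save_gradient(grad):
    # Called during backward() with d(logit) / d(activations).
    captured["grad"] = grad

x = torch.randn(1, 1, 8, 8)
activations = conv(x)
activations.register_hook(save_gradient)  # hook on the activation tensor

logits = head(activations)
class_idx = int(logits.argmax())
logits[0, class_idx].backward()

# The activations are an intermediate (non-leaf) tensor ...
print(activations.is_leaf)
# ... parameters are leaves, so their gradient is retained ...
print(conv.weight.grad is not None)
# ... and the hook captured the gradient of the activations.
print(captured["grad"].shape)
```

The captured gradient has the same shape as the activations, which is what the channel-wise pooling step needs.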
Load the data
Load some images and apply the ImageNet transformation. You may need to change the ‘datasetdir’ parameter.
from pynet.datasets import DataManager, fetch_gradcam
from pynet.plotting import plot_data

data = fetch_gradcam(
    datasetdir="/tmp/gradcam")
manager = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    number_of_folds=2,
    batch_size=5,
    test_size=1)
dataset = manager["test"]
print(dataset.inputs.shape)
plot_data(dataset.inputs, nb_samples=5, random=False, rgb=True)
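For readers unfamiliar with the term, an "ImageNet transformation" conventionally means scaling pixel values to [0, 1] and normalizing each RGB channel with the ImageNet mean and standard deviation. Whether ‘fetch_gradcam’ applies exactly these steps is an assumption; the helper `imagenet_normalize` below is hypothetical and only illustrates the normalization on a random batch.

```python
import torch

# Per-channel statistics conventionally used with ImageNet-pretrained models.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def imagenet_normalize(batch):
    """Normalize a (N, 3, H, W) uint8 batch with values in [0, 255]."""
    scaled = batch.float() / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD

# Random stand-in for five RGB images of size 224 x 224.
batch = torch.randint(0, 256, (5, 3, 224, 224), dtype=torch.uint8)
normalized = imagenet_normalize(batch)
print(normalized.shape, normalized.dtype)
```

Pretrained classification networks are normally fed inputs normalized this way, so the same statistics have to be undone before an image is displayed in its original colors.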
Explore different architectures
Let’s automate this procedure for different networks. The inception network needs its own copy of the data, so we load the data a second time with ‘inception=True’. You may need to change the ‘datasetdir’ parameter.
import os
from pynet.models.cam import get_cam_network
from pynet.cam import GradCam
import matplotlib.pyplot as plt

data = fetch_gradcam(
    datasetdir="/tmp/gradcam")
manager1 = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    number_of_folds=2,
    batch_size=1,
    test_size=1)
loaders1 = manager1.get_dataloader(test=True)
data = fetch_gradcam(
    datasetdir="/tmp/gradcam",
    inception=True)
manager2 = DataManager(
    input_path=data.input_path,
    metadata_path=data.metadata_path,
    number_of_folds=2,
    batch_size=1,
    test_size=1)
loaders2 = manager2.get_dataloader(test=True)

for loaders, model_name in ((loaders1, "vgg19"),
                            (loaders1, "densenet201"),
                            (loaders1, "resnet18")):
                            # (loaders2, "inception_v3")):
    heatmaps = []
    print("-" * 10)
    print(model_name)
    for dataitem in loaders.test:
        model, activation_layer_name = get_cam_network(model_name)
        grad_cam = GradCam(model, [activation_layer_name], data.labels, top=1)
        heatmaps.extend(grad_cam(dataitem.inputs).items())
    fig, axs = plt.subplots(nrows=2, ncols=len(heatmaps))
    fig.suptitle(model_name, fontsize="large")
    for cnt, (name, (img, arr, arr_highres)) in enumerate(heatmaps):
        axs[0, cnt].set_title(name)
        axs[0, cnt].matshow(arr)
        axs[0, cnt].set_axis_off()
        _img = img.data.numpy()[0].transpose((1, 2, 0))
        axs[1, cnt].imshow(_img)
        axs[1, cnt].imshow(arr_highres, alpha=0.6, cmap="jet")
        axs[1, cnt].set_axis_off()

if "CI_MODE" not in os.environ:
    plt.show()