# Helper Module for Deep Learning.
# Source code for pynet.cam
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Module that provides tools to compute class activation map.
"""
# Imports
import logging
import skimage
import skimage.transform
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as func
# Global parameters
# Module logger obtained under the "pynet" name; GradCam reports its ranked
# class predictions (probability -> label) through it.
logger = logging.getLogger("pynet")
class FeatureExtractor(object):
    """ Class for extracting activations and registering gradients from
    targeted intermediate layers.

    The direct children of ``model`` are applied one after the other to the
    input. The outputs of the children whose names are listed in
    ``target_layers`` are returned, and their gradients are collected in
    ``self.gradients`` during the backward pass.

    Parameters
    ----------
    model: torch.nn.Module
        container whose direct children (``model._modules``) are applied
        sequentially to the input.
    target_layers: list of str
        names of the children whose outputs are recorded.
    """
    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        # Gradients of the targeted activations, appended by the hook in the
        # order autograd computes them (i.e. reverse forward order).
        self.gradients = []

    def save_gradient(self, grad):
        """ Tensor hook: store the gradient of a targeted activation.

        Parameters
        ----------
        grad: torch.Tensor
            gradient of the hooked activation.
        """
        self.gradients.append(grad)

    def __call__(self, x):
        """ Forward ``x`` through the children of the wrapped model.

        Parameters
        ----------
        x: torch.Tensor
            input of the first child module.

        Returns
        -------
        outputs: list of torch.Tensor
            activations of the targeted children, in forward order.
        x: torch.Tensor
            output of the last child module.
        """
        outputs = []
        # Each forward pass starts a fresh gradient record.
        self.gradients = []
        for name, module in self.model._modules.items():
            x = module(x)
            if name in self.target_layers:
                # Hooks can only be attached to tensors tracked by autograd;
                # skipping them keeps forward passes under no_grad working.
                if x.requires_grad:
                    x.register_hook(self.save_gradient)
                outputs.append(x)
        return outputs, x
class ModelOutputs(object):
    """ Class for making a forward pass, and getting:
    1- the network output.
    2- activations from intermediate targeted layers.
    3- gradients from intermediate targeted layers (after a backward pass,
       through ``get_activations_gradient``).

    Parameters
    ----------
    model: torch.nn.Module
        network exposing ``features`` and ``classifier`` sub-modules, and
        optionally ``pre`` (applied before ``features``) and ``pool``
        (applied between ``features`` and ``classifier``).
    target_layers: list of str
        names of the ``model.features`` children whose activations are
        recorded.
    """
    def __init__(self, model, target_layers):
        self.model = model
        self.feature_extractor = FeatureExtractor(
            self.model.features, target_layers)

    def get_activations_gradient(self):
        """ Return the gradients recorded for the targeted layers since the
        last forward pass.

        Returns
        -------
        gradients: list of torch.Tensor
            gradients in the order autograd produced them.
        """
        return self.feature_extractor.gradients

    def __call__(self, x):
        """ Run the full network.

        Parameters
        ----------
        x: torch.Tensor
            network input.

        Returns
        -------
        target_activations: list of torch.Tensor
            activations of the targeted ``features`` children.
        output: torch.Tensor
            classifier output.
        """
        if hasattr(self.model, "pre"):
            x = self.model.pre(x)
        target_activations, output = self.feature_extractor(x)
        if hasattr(self.model, "pool"):
            output = self.model.pool(output)
        # flatten (unlike view) also accepts non-contiguous feature maps.
        output = torch.flatten(output, 1)
        output = self.model.classifier(output)
        return target_activations, output
class GradCam(object):
    """ Class for computing class activation map (Grad-CAM).

    Parameters
    ----------
    model: torch.nn.Module
        network exposing ``features`` and ``classifier`` sub-modules (and
        optionally ``pre`` and ``pool``); it is switched to evaluation mode.
    target_layers: list of str
        names of the ``model.features`` children whose activations are
        recorded; the heatmap is computed from the last one reached in the
        forward pass.
    labels: dict
        maps ``str(class_index)`` to a sequence whose item ``[1]`` is used
        as the class name.
    top: int, default 1
        number of highest scoring classes for which a heatmap is computed.
    """
    def __init__(self, model, target_layers, labels, top=1):
        self.model = model
        self.labels = labels
        self.top = top
        self.model.eval()
        self.extractor = ModelOutputs(self.model, target_layers)

    @staticmethod
    def _compute_heatmap(activations, gradients):
        """ Combine activations and their gradients into a normalized map.

        The input arrays are not modified.

        Parameters
        ----------
        activations: numpy.ndarray
            layer activations laid out as (batch, channels, spatial...).
        gradients: numpy.ndarray
            gradients of the class score w.r.t. ``activations`` (same shape).

        Returns
        -------
        heatmap: numpy.ndarray
            channel-weighted, ReLU-ed map scaled to [0, 1], with singleton
            axes squeezed; all zeros when no location has positive evidence.
        """
        # Channel importance: gradients averaged over the batch axis and
        # every spatial axis (handles 2D images as well as 3D volumes).
        reduce_axes = (0, ) + tuple(range(2, gradients.ndim))
        weights = np.mean(gradients, axis=reduce_axes)
        # Broadcast the per-channel weights along axis 1 without mutating
        # the activations.
        weights = weights.reshape((1, -1) + (1, ) * (activations.ndim - 2))
        heatmap = np.mean(activations * weights, axis=1).squeeze()
        heatmap = np.maximum(heatmap, 0)
        heatmap = heatmap - np.min(heatmap)
        peak = np.max(heatmap)
        # A map with no positive evidence would otherwise turn into NaNs.
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap

    def __call__(self, input):
        """ Compute Grad-CAM heatmaps for the ``top`` predicted classes.

        Parameters
        ----------
        input: torch.Tensor
            model input laid out as (batch, channels, spatial...); the class
            ranking uses the first sample only, so presumably a batch of one.

        Returns
        -------
        heatmaps: dict
            maps each class name to ``(input, heatmap, heatmap_highres)``,
            where ``heatmap`` has the resolution of the last targeted layer
            and ``heatmap_highres`` is resized to ``input.shape[2:]``.
        """
        features, output = self.extractor(input)
        target = features[-1]
        # Host copy shared by every class: never modified in place, so the
        # graph tensor (and its saved values) stay intact across classes.
        activations = target.detach().cpu().numpy()
        pred_prob = func.softmax(output, dim=1).detach()[0]
        probs, indices = pred_prob.sort(0, True)
        probs = probs.cpu().numpy()
        indices = indices.cpu().numpy()
        heatmaps = {}
        for prob, index in zip(probs[:self.top], indices[:self.top]):
            label = self.labels[str(index)][1]
            logger.info("%.3f -> %s", prob, label)
            # Same value as summing a one-hot mask times the output, but
            # computed on the output's own device.
            score = output[:, int(index)].sum()
            # Differentiate directly w.r.t. the last targeted activation so
            # the gradient is always paired with the right layer.
            gradients = torch.autograd.grad(
                score, target, retain_graph=True)[0]
            heatmap = self._compute_heatmap(
                activations, gradients.detach().cpu().numpy())
            heatmap_highres = skimage.transform.resize(
                heatmap, input.shape[2:])
            heatmaps[label] = (input, heatmap, heatmap_highres)
        return heatmaps
# Follow us
# © 2019, pynet developers.
# Inspired by AZMIND template.