Helper Module for Deep Learning.
Source code for pynet.losses.segmentation
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2019 - 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Module that provides segmentation losses.
"""
# System and third party imports
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.distributions import Normal, kl_divergence
from torch.autograd import Variable
from pynet.utils import Losses
# Global parameters
logger = logging.getLogger("pynet")
@Losses.register
class FocalLoss(object):
    """ Define a Focal Loss.

    Loss(pt) = - alpha_t * (1 - pt)^gamma * log(pt)

    where pt is the model's estimated probability for the true class.

    When an example is misclassified and pt is small, the modulating factor
    (1 - pt)^gamma is near 1 and the loss is unaffected. As pt goes to 1, the
    factor goes to 0 and the loss for well-classified examples is
    down-weighted.

    The focusing parameter gamma smoothly adjusts the rate at which easy
    examples are down-weighted. When gamma = 0, the loss is equivalent to
    cross entropy, and as gamma is increased the effect of the modulating
    factor is likewise increased. For instance, with gamma = 2, an example
    classified with pt = 0.9 would have 100x lower loss compared with cross
    entropy and with pt ~ 0.968 it would have 1000x lower loss.

    An alpha-balanced variant of the focal loss is used to address class
    imbalance with a weighting factor alpha in [0, 1]. In practice alpha may
    be set by inverse class frequency.

    Reference: https://arxiv.org/abs/1708.02002
    """
    def __init__(self, n_classes, gamma=2, alpha=None, reduction="mean",
                 with_logit=True):
        """ Class instantiation.

        Parameters
        ----------
        n_classes: int
            the number of classes.
        gamma: float, default 2
            the focusing parameter >=0.
        alpha: float or list/tuple of float, default None
            if set use alpha-balanced variant of the focal loss. A scalar is
            repeated for every class; a falsy value (None, 0) means 1.
        reduction: str, default 'mean'
            specifies the reduction to apply to the output: 'none' - no
            reduction will be applied, 'mean' - the mean of the output
            normalized by the mean alpha of the targets, 'sum' - the output
            will be summed.
        with_logit: bool, default True
            apply the softmax function to the input.

        Raises
        ------
        ValueError
            if the number of alphas differs from the number of classes.
        """
        self.gamma = gamma
        self.reduction = reduction
        self.with_logit = with_logit
        self.eps = 1e-9
        alpha = alpha or 1
        # Tuples are accepted like lists rather than being treated as a
        # scalar and repeated n_classes times.
        if not isinstance(alpha, (list, tuple)):
            alpha = [alpha] * n_classes
        if len(alpha) != n_classes:
            raise ValueError("Invalid alphas size.")
        logger.debug(" alpha: {0}".format(alpha))
        # Stored as a (C, 1) column so that it can be indexed by label.
        self.alpha = torch.FloatTensor(list(alpha)).view(-1, 1)
        # self.alpha = self.alpha / self.alpha.sum()
        self.debug("alpha", self.alpha)

    def _check_inputs(self, output, target):
        """ Validate that predictions and labels are compatible.

        Raises
        ------
        ValueError
            if the shapes or the devices of the two tensors do not match.
        """
        if len(output.shape) < 2:
            raise ValueError("Invalid labels shape {0}.".format(output.shape))
        if output.shape[0] != target.shape[0]:
            raise ValueError("Expected pred & true labels same batch size.")
        if output.shape[2:] != target.shape[1:]:
            raise ValueError("Expected pred & true labels same data size.")
        if output.device != target.device:
            raise ValueError("Pred & true labels must be in the same device.")

    def _reduce(self, loss, target):
        """ Apply the configured reduction to the element-wise loss.

        Raises
        ------
        NotImplementedError
            if the reduction mode is not 'none', 'mean' or 'sum'.
        """
        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            # Normalize by the mean class weight of the observed targets.
            return torch.mean(loss) / self.alpha[target].mean()
        if self.reduction == "sum":
            return torch.sum(loss)
        raise NotImplementedError("Invalid reduction mode.")

    def __call__(self, output, target):
        """ Compute the loss.

        Parameters
        ----------
        output: Tensor (N,C,*)
            predicted labels where C is the number of classes.
        target: Tensor (N,*)
            true labels where each value is 0<=target[i]<=C-1.

        Returns
        -------
        loss: Tensor
            the reduced loss, or one value per element when reduction is
            'none'.
        """
        logger.debug("Focal loss...")
        self.debug("output", output)
        self.debug("target", target)
        self._check_inputs(output, target)
        n_batch, n_classes = output.shape[:2]
        device = output.device
        dim = output.dim()
        logger.debug(" n_batches: {0}".format(n_batch))
        logger.debug(" n_classes: {0}".format(n_classes))
        logger.debug(" dim: {0}".format(dim))
        if self.with_logit:
            output = func.softmax(output, dim=1)
        logit = output + self.eps
        self.debug("logit", logit)

        # Reshape data
        # N,C,d1,d2 -> N,C,m (m=d1*d2*...) -> N*m,C
        if dim > 2:
            logit = logit.reshape(n_batch, n_classes, -1)
            self.debug("logit", logit)
            logit = logit.permute(0, 2, 1).contiguous()
            self.debug("logit", logit)
            logit = logit.view(-1, n_classes)
            self.debug("logit", logit)
        # Flatten the labels to a column whatever their rank: this also
        # supports the plain classification case (N,C) / (N,) where there
        # is no dimension 1 to squeeze.
        target = target.reshape(-1, 1)
        self.debug("target", target)

        # Create the labels one hot encoded tensor
        one_hot = torch.zeros(target.size(0), n_classes,
                              device=device, dtype=output.dtype)
        target_one_hot = one_hot.scatter_(1, target, 1.) + self.eps

        # Compute the focal loss
        if self.alpha.device != device:
            self.alpha = self.alpha.to(device)
        pt = torch.sum(target_one_hot * logit, dim=1)
        self.debug("pt", pt)
        logpt = torch.log(pt)
        # The added eps can push pt slightly above 1: clamp so that a
        # fractional gamma never sees a negative base (which yields NaN).
        weight = torch.pow(torch.clamp(1 - pt, min=0), self.gamma)
        self.debug("weight", weight)
        alpha = self.alpha[target].view(-1)
        self.debug("alpha", alpha)
        loss = -1 * alpha * weight * logpt
        self.debug("loss", loss)

        # Reduction
        loss = self._reduce(loss, target)
        logger.debug(" loss: {0}".format(loss))
        return loss

    def _forward_without_resizing(self, output, target):
        """ Compute the loss without flattening the data dimensions.

        Parameters
        ----------
        output: Tensor (N,C,*)
            predicted labels where C is the number of classes.
        target: Tensor (N,*)
            true labels where each value is 0<=target[i]<=C-1.

        Returns
        -------
        loss: Tensor
            the reduced loss, or a (N,*) tensor when reduction is 'none'.
        """
        logger.debug("Focal loss...")
        self.debug("output", output)
        self.debug("target", target)
        self._check_inputs(output, target)
        n_batch, n_classes = output.shape[:2]
        device = output.device
        dim = output.dim()
        logger.debug(" n_batches: {0}".format(n_batch))
        logger.debug(" n_classes: {0}".format(n_classes))
        logger.debug(" dim: {0}".format(dim))
        if self.with_logit:
            output = func.softmax(output, dim=1)
        logit = output + self.eps
        self.debug("logit", logit)

        # Create the labels one hot encoded tensor
        one_hot = torch.zeros(n_batch, n_classes, *target.shape[1:],
                              device=device, dtype=output.dtype)
        target_one_hot = one_hot.scatter_(
            1, target.unsqueeze(1), 1.) + self.eps

        # Compute the focal loss
        if self.alpha.device != device:
            self.alpha = self.alpha.to(device)
        # Same clamping as in __call__ to avoid NaN with a fractional gamma.
        weight = torch.pow(torch.clamp(1 - logit, min=0), self.gamma)
        self.debug("weight", weight)
        # Broadcast the per-class weights along the class axis.
        shape = [1, n_classes] + [1] * len(target.shape[1:])
        alpha = self.alpha.view(*shape)
        alpha = alpha.expand_as(weight)
        self.debug("alpha", alpha)
        focal = -1 * alpha * weight * torch.log(logit)
        self.debug("focal", focal)
        loss = torch.sum(target_one_hot * focal, dim=1)
        self.debug("loss", loss)

        # Reduction
        loss = self._reduce(loss, target)
        logger.debug(" loss: {0}".format(loss))
        return loss

    def debug(self, name, tensor):
        """ Print debug message.

        Parameters
        ----------
        name: str
            the tensor name in the displayed message.
        tensor: Tensor
            a pytorch tensor.
        """
        # Skip the formatting work entirely when debug logging is disabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(" {3}: {0} - {1} - {2}".format(
                tensor.shape, tensor.get_device(), tensor.dtype, name))
@Losses.register
class MaskLoss(object):
    """ Define a Masked Loss.

    Loss(pt) = - alpha_t * m_t * log(pt)

    where pt is the model's estimated probability for the true class, alpha_t
    the weight of the true class and m_t the value of the mask at the
    considered location.
    """
    def __init__(self, n_classes, beta=0.2, alpha=None, reduction="mean",
                 with_logit=True):
        """ Class instantiation.

        Parameters
        ----------
        n_classes: int
            the number of classes.
        beta: float, default 0.2
            the minimum value in the mask.
        alpha: float or list of float, default None
            if set use the alpha-balanced variant of the loss. A scalar is
            repeated for every class; a falsy value (None, 0) means 1.
        reduction: str, default 'mean'
            specifies the reduction to apply to the output: 'none' - no
            reduction will be applied, 'mean' - the mean of the output
            normalized by the mean alpha of the targets, 'sum' - the output
            will be summed.
        with_logit: bool, default True
            apply the log softmax function to the input.

        Raises
        ------
        ValueError
            if the number of alphas differs from the number of classes.
        """
        self.beta = beta
        self.reduction = reduction
        self.with_logit = with_logit
        self.eps = 1e-9
        alpha = alpha or 1
        if not isinstance(alpha, list):
            alpha = [alpha] * n_classes
        if len(alpha) != n_classes:
            raise ValueError("Invalid alphas size.")
        logger.debug(" alpha: {0}".format(alpha))
        self.alpha = torch.FloatTensor(alpha)
        self.debug("alpha", self.alpha)

    def __call__(self, output, target, mask=None):
        """ Compute the loss.

        Parameters
        ----------
        output: Tensor (N,C,*)
            predicted labels where C is the number of classes.
        target: Tensor (N,*)
            true labels where each value is 0<=target[i]<=C-1.
        mask: Tensor (N,*), default None
            the mask used to weight the loss. Values lower than beta are
            raised to beta. The input tensor is left untouched. If None, no
            masking is applied.

        Returns
        -------
        loss: Tensor
            the reduced loss, or a (N,*) tensor when reduction is 'none'.

        Raises
        ------
        ValueError
            if the shapes or the devices of the input tensors do not match.
        NotImplementedError
            if the reduction mode is not 'none', 'mean' or 'sum'.
        """
        logger.debug("Maked loss...")
        self.debug("output", output)
        self.debug("target", target)
        if mask is not None:
            self.debug("mask", mask)
        if len(output.shape) < 2:
            raise ValueError("Invalid labels shape {0}.".format(output.shape))
        if output.shape[0] != target.shape[0]:
            raise ValueError("Expected pred & true labels same batch size.")
        if output.shape[2:] != target.shape[1:]:
            raise ValueError("Expected pred & true labels same data size.")
        if output.device != target.device:
            raise ValueError("Pred & true labels must be in the same device.")
        if mask is not None:
            if output.shape[0] != mask.shape[0]:
                raise ValueError("Expected pred & mask same batch size.")
            if output.shape[2:] != mask.shape[1:]:
                raise ValueError("Expected pred & mask same data size.")
            if output.device != mask.device:
                raise ValueError("Pred & mask must be in the same device.")
        n_batch, n_classes = output.shape[:2]
        device = output.device
        logger.debug(" n_batches: {0}".format(n_batch))
        logger.debug(" n_classes: {0}".format(n_classes))
        if self.alpha.device != device:
            self.alpha = self.alpha.to(device)
        if self.with_logit:
            output = func.log_softmax(output, dim=1)
        # NOTE(review): eps is added to values that nll_loss interprets as
        # log-probabilities; kept to preserve the historical numeric output.
        logit = output + self.eps
        self.debug("logit", logit)

        # Compute the masked loss
        loss = func.nll_loss(logit, target, weight=self.alpha,
                             reduction="none")
        if mask is not None:
            # Work on a floating point copy: the caller's mask is not
            # modified and beta is not truncated for integer/bool masks.
            mask = torch.clamp(mask.to(loss.dtype), min=self.beta)
            loss = loss * mask
        self.debug("loss", loss)

        # Reduction
        if self.reduction == "none":
            pass
        elif self.reduction == "mean":
            loss = torch.mean(loss) / self.alpha[target].mean()
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        else:
            raise NotImplementedError("Invalid reduction mode.")
        logger.debug(" loss: {0}".format(loss))
        return loss

    def debug(self, name, tensor):
        """ Print debug message.

        Parameters
        ----------
        name: str
            the tensor name in the displayed message.
        tensor: Tensor
            a pytorch tensor.
        """
        # Skip the formatting work entirely when debug logging is disabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(" {3}: {0} - {1} - {2}".format(
                tensor.shape, tensor.get_device(), tensor.dtype, name))
@Losses.register
class SoftDiceLoss(object):
    """ Multi-class soft Dice loss.

    Dice = (2 intersec Y) / (X + Y)

    Optimizers minimize their objective whereas the Dice score must be
    maximized, hence the returned value is 1 - Dice.
    """
    def __init__(self, with_logit=True, reduction="mean"):
        """ Class instantiation.

        Parameters
        ----------
        with_logit: bool, default True
            apply the softmax function to the input.
        reduction: str, default 'mean'
            specifies the reduction to apply to the output: 'none' - no
            reduction will be applied, 'mean' - the sum of the output
            will be divided by the number of elements in the output, 'sum'
            - the output will be summed.
        """
        self.smooth = 1e-6
        self.eps = 1e-6
        self.reduction = reduction
        self.with_logit = with_logit

    def _validate(self, output, target):
        """ Raise a ValueError if predictions and labels are incompatible.
        """
        if len(output.shape) < 2:
            raise ValueError("Invalid labels shape {0}.".format(output.shape))
        same_batch = (output.shape[0] == target.shape[0])
        if not same_batch:
            raise ValueError("Expected pred & true labels same batch size.")
        same_data = (output.shape[2:] == target.shape[1:])
        if not same_data:
            raise ValueError("Expected pred & true labels same data size.")
        same_device = (output.device == target.device)
        if not same_device:
            raise ValueError("Pred & true labels must be in the same device.")

    def _probabilities(self, output):
        """ Return class probabilities, applying a softmax if requested.
        """
        return func.softmax(output, dim=1) if self.with_logit else output

    def _reduce(self, loss):
        """ Apply the configured reduction, or raise NotImplementedError.
        """
        mode = self.reduction
        if mode == "none":
            return loss
        if mode == "mean":
            return torch.mean(loss)
        if mode == "sum":
            return torch.sum(loss)
        raise NotImplementedError("Invalid reduction mode.")

    def __call__(self, output, target):
        """ Compute the loss on flattened (N, C * m) tensors.

        Parameters
        ----------
        output: Tensor (N,C,*)
            predicted labels where C is the number of classes.
        target: Tensor (N,*)
            true labels where each value is 0<=targets[i]<=C-1.

        Returns
        -------
        loss: Tensor
            the reduced loss, or one value per sample when reduction is
            'none'.
        """
        logger.debug("Dice loss...")
        self.debug("output", output)
        self.debug("target", target)
        self._validate(output, target)
        batch_size, class_count = output.shape[:2]
        out_device = output.device
        scores = self._probabilities(output)
        self.debug("logit", scores)

        # One hot encode the labels: (N,*) -> (N,*,C) -> (N,C,*) -> (N,C*m)
        scores = scores.view(batch_size, -1)
        rank = len(target.shape)
        dims = (0, rank) + tuple(range(1, rank))
        logger.debug("permute {0}".format(dims))
        encoded = func.one_hot(target, num_classes=class_count)
        self.debug("target_one_hot", encoded)
        encoded = encoded.permute(dims).contiguous()
        encoded = encoded.view(batch_size, -1)
        if encoded.device != out_device:
            encoded = encoded.to(out_device)
        self.debug("target_one_hot", encoded)

        # Dice score of each sample, smoothed to avoid a null denominator
        overlap = scores * encoded
        self.debug("intersection", overlap)
        numerator = 2 * overlap.sum(dim=1) + self.smooth
        denominator = encoded.sum(dim=1) + scores.sum(dim=1) + self.smooth
        loss = 1. - numerator / denominator
        self.debug("loss", loss)

        loss = self._reduce(loss)
        logger.debug(" loss: {0}".format(loss))
        return loss

    def _forward_without_resizing(self, output, target):
        """ Compute the loss without flattening the data dimensions.

        Parameters
        ----------
        output: Tensor (N,C,*)
            predicted labels where C is the number of classes.
        target: Tensor (N,*)
            true labels where each value is 0<=targets[i]<=C-1.

        Returns
        -------
        loss: Tensor
            the reduced loss, or one value per sample when reduction is
            'none'.
        """
        logger.debug("Dice loss...")
        self.debug("output", output)
        self.debug("target", target)
        self._validate(output, target)
        batch_size, class_count = output.shape[:2]
        scores = self._probabilities(output)
        self.debug("logit", scores)

        # One hot encode the labels directly in the (N,C,*) layout
        encoded = torch.zeros(batch_size, class_count, *target.shape[1:],
                              device=output.device, dtype=output.dtype)
        encoded = encoded.scatter_(1, target.unsqueeze(1), 1.)
        self.debug("one hot", encoded)

        # Dice score of each sample: sum over every axis but the batch one
        reduced_axes = tuple(range(1, len(target.shape) + 1))
        overlap = torch.sum(scores * encoded, reduced_axes)
        self.debug("intersection", overlap)
        total = torch.sum(scores + encoded, reduced_axes)
        self.debug("cardinality", total)
        loss = 1. - 2. * overlap / (total + self.eps)
        self.debug("loss", loss)

        loss = self._reduce(loss)
        logger.debug(" loss: {0}".format(loss))
        return loss

    def debug(self, name, tensor):
        """ Log the shape, device index and dtype of a tensor.

        Parameters
        ----------
        name: str
            the tensor name in the displayed message.
        tensor: Tensor
            a pytorch tensor.
        """
        message = " {3}: {0} - {1} - {2}".format(
            tensor.shape, tensor.get_device(), tensor.dtype, name)
        logger.debug(message)
@Losses.register
class CustomKLLoss(object):
    """ KL Loss.

    Kullback-Leibler divergence between a posterior distribution and the
    standard normal distribution N(0, 1), summed over the last dimension.
    """
    def __init__(self, reduction="mean"):
        """ Class instantiation.

        Parameters
        ----------
        reduction: str, default 'mean'
            specifies the reduction to apply over the first dimension of
            the output: 'none' - no reduction will be applied, 'mean' - the
            output will be averaged, 'sum' - the output will be summed.
        """
        super(CustomKLLoss, self).__init__()
        self.reduction = reduction

    def __call__(self, posterior):
        """ Compute the loss.

        Parameters
        ----------
        posterior: torch.distributions.Distribution
            the estimated posterior distribution; it must be supported by
            'kl_divergence' against a Normal distribution.

        Returns
        -------
        kl_loss: Tensor
            the divergence reduced over the first dimension, or unreduced
            when reduction is 'none'.

        Raises
        ------
        NotImplementedError
            if the reduction mode is not 'none', 'mean' or 'sum'.
        """
        kl_loss = kl_divergence(posterior, Normal(0, 1)).sum(-1, keepdim=True)
        if self.reduction == "mean":
            return kl_loss.mean(0)
        elif self.reduction == "sum":
            return kl_loss.sum(0)
        elif self.reduction == "none":
            return kl_loss
        else:
            # Raise rather than return the exception class, which would
            # silently flow on as if it were a loss value.
            raise NotImplementedError("Invalid reduction mode.")
@Losses.register
class NvNetCombinedLoss(object):
    """ Combined Loss.

    Cross Entropy loss + k1 * L2 loss + k2 * KL loss

    The first 'num_classes' channels of the prediction hold the segmentation
    and are scored with a cross entropy loss; the remaining channels hold
    the VAE reconstruction and are scored with a L2 loss. The KL term
    regularizes the VAE posterior. The VAE terms are only used when the
    'layer_outputs' attribute has been set.

    A hyper-parameter weight of k1=0.1, k2=0.1 was found empirically in the
    paper.
    """
    def __init__(self, num_classes, k1=0.1, k2=0.1):
        """ Class instantiation.

        Parameters
        ----------
        num_classes: int
            the number of segmentation channels.
        k1: float, default 0.1
            the weight of the L2 reconstruction loss.
        k2: float, default 0.1
            the weight of the KL loss.
        """
        super(NvNetCombinedLoss, self).__init__()
        # Expected to be filled from outside with a dict providing the VAE
        # posterior under the 'q' key; None disables the VAE terms.
        self.layer_outputs = None
        self.num_classes = num_classes
        self.k1 = k1
        self.k2 = k2
        self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
        self.l2_loss = nn.MSELoss(reduction="mean")
        self.kl_loss = CustomKLLoss(reduction="mean")

    def __call__(self, output, target):
        """ Compute the loss.

        Parameters
        ----------
        output: Tensor (N,C,*)
            the predictions: segmentation channels first, followed by the
            VAE reconstruction channels.
        target: Tensor (N,C,*)
            the ground truth, with the same layout and shape as the
            predictions; segmentation channels are reduced to labels with
            an argmax.

        Returns
        -------
        combined_loss: Tensor
            the weighted sum of the losses.
        extra: dict
            the individual 'l2_loss', 'kl_loss' and 'ce_loss' terms.

        Raises
        ------
        ValueError
            if the shapes or the devices of the input tensors do not match.
        """
        logger.debug("NvNet Combined Loss...")
        self.debug("output", output)
        self.debug("target", target)
        use_vae = self.layer_outputs is not None
        if use_vae:
            posterior = self.layer_outputs["q"]
        if len(output.shape) < 2:
            raise ValueError("Invalid labels shape {0}.".format(output.shape))
        if output.shape != target.shape:
            raise ValueError("Expected pred & true of same size.")
        if output.device != target.device:
            raise ValueError("Pred & true labels must be in the same device.")
        if use_vae:
            vae_pred = output[:, self.num_classes:]
            vae_truth = target[:, self.num_classes:]
            self.debug("vae_pred", vae_pred)
            self.debug("vae_truth", vae_truth)
        seg_pred = output[:, :self.num_classes]
        seg_truth = target[:, :self.num_classes]
        self.debug("seg_pred", seg_pred)
        self.debug("seg_truth", seg_truth)
        # argmax already returns int64 indices on the input device, so no
        # host round trip through a CPU LongTensor is needed.
        seg_truth = torch.argmax(seg_truth, dim=1)
        self.debug("seg_truth", seg_truth)
        ce_loss = self.ce_loss(seg_pred, seg_truth)
        if use_vae:
            l2_loss = self.l2_loss(vae_pred, vae_truth)
            kl_div = self.kl_loss(posterior)
            combined_loss = ce_loss + self.k1 * l2_loss + self.k2 * kl_div
        else:
            l2_loss, kl_div = 0, 0
            combined_loss = ce_loss
        logger.debug(
            "ce_loss: {0}, L2_loss: {1}, KL_div: {2}, combined_loss: "
            "{3}".format(ce_loss, l2_loss, kl_div, combined_loss))
        return combined_loss, {"l2_loss": l2_loss, "kl_loss": kl_div,
                               "ce_loss": ce_loss}

    def debug(self, name, tensor):
        """ Print debug message.

        Parameters
        ----------
        name: str
            the tensor name in the displayed message.
        tensor: Tensor
            a pytorch tensor.
        """
        # Skip the formatting work entirely when debug logging is disabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(" {3}: {0} - {1} - {2}".format(
                tensor.shape, tensor.get_device(), tensor.dtype, name))
Follow us
© 2019, pynet developers .
Inspired by AZMIND template.
Inspired by AZMIND template.