import torch
import torch.nn.functional as F
from ..registry import LOSSES
from .base import BaseWeightedLoss
@LOSSES.register_module()
class CrossEntropyLoss(BaseWeightedLoss):
    """Cross Entropy Loss.

    Support two kinds of labels and their corresponding loss type. It's worth
    mentioning that loss type will be detected by the shape of ``cls_score``
    and ``label``.

    1) Hard label: This label is an integer array and all of the elements are
       in the range [0, num_classes - 1]. This label's shape should be
       ``cls_score``'s shape with the `num_classes` dimension removed.
    2) Soft label (probability distribution over classes): This label is a
       probability distribution and all of the elements are in the range
       [0, 1]. This label's shape must be the same as ``cls_score``. For now,
       only 2-dim soft label is supported.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Default: 1.0.
        class_weight (list[float] | None): Loss weight for each class. If set
            as None, use the same weight 1 for all classes. Only applies
            to CrossEntropyLoss and BCELossWithLogits (should not be set when
            using other losses). Default: None.
    """

    def __init__(self, loss_weight=1.0, class_weight=None):
        super().__init__(loss_weight=loss_weight)
        # Kept as a plain attribute (no ``register_buffer``), so it is created
        # on the CPU; ``_forward`` moves it to ``cls_score``'s device per call.
        self.class_weight = None
        if class_weight is not None:
            self.class_weight = torch.Tensor(class_weight)

    def _forward(self, cls_score, label, **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label. If it has the same
                size as ``cls_score`` it is treated as a soft label,
                otherwise as a hard (class-index) label.
            kwargs: Any keyword argument to be used to calculate
                CrossEntropy loss. Only accepted for hard labels, where they
                are forwarded to ``F.cross_entropy``.

        Returns:
            torch.Tensor: The returned CrossEntropy loss.

        Raises:
            AssertionError: For soft labels, if ``cls_score`` is not 2-dim or
                any ``kwargs`` are given. For hard labels, if ``class_weight``
                is set and ``kwargs`` already contains ``'weight'``.
        """
        # The class weights live on the CPU (see ``__init__``); move them next
        # to the scores once so both branches avoid a device mismatch.
        class_weight = None
        if self.class_weight is not None:
            class_weight = self.class_weight.to(cls_score.device)

        if cls_score.size() == label.size():
            # calculate loss for soft label
            assert cls_score.dim() == 2, 'Only support 2-dim soft label'
            assert len(kwargs) == 0, \
                ('For now, no extra args are supported for soft label, '
                 f'but get {kwargs}')

            lsm = F.log_softmax(cls_score, 1)
            if class_weight is not None:
                # Scale each class column of the log-probabilities.
                lsm = lsm * class_weight.unsqueeze(0)
            loss_cls = -(label * lsm).sum(1)

            # default reduction 'mean'
            if class_weight is not None:
                # Use weighted average as pytorch CrossEntropyLoss does: the
                # normaliser is the label-weighted sum of class weights, which
                # for one-hot labels matches the hard-label weighted mean.
                # For more information, please visit https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html  # noqa
                loss_cls = loss_cls.sum() / torch.sum(
                    class_weight.unsqueeze(0) * label)
            else:
                loss_cls = loss_cls.mean()
        else:
            # calculate loss for hard label
            if class_weight is not None:
                assert 'weight' not in kwargs, \
                    "The key 'weight' already exists."
                kwargs['weight'] = class_weight
            loss_cls = F.cross_entropy(cls_score, label, **kwargs)
        return loss_cls
@LOSSES.register_module()
class BCELossWithLogits(BaseWeightedLoss):
    """Binary Cross Entropy Loss with logits.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Default: 1.0.
        class_weight (list[float] | None): Loss weight for each class. If set
            as None, use the same weight 1 for all classes. Only applies
            to CrossEntropyLoss and BCELossWithLogits (should not be set when
            using other losses). Default: None.
    """

    def __init__(self, loss_weight=1.0, class_weight=None):
        super().__init__(loss_weight=loss_weight)
        # Kept as a plain attribute (no ``register_buffer``), so it is created
        # on the CPU; ``_forward`` moves it to ``cls_score``'s device per call.
        self.class_weight = None
        if class_weight is not None:
            self.class_weight = torch.Tensor(class_weight)

    def _forward(self, cls_score, label, **kwargs):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label.
            kwargs: Any keyword argument to be used to calculate
                bce loss with logits. Forwarded unchanged to
                ``F.binary_cross_entropy_with_logits``.

        Returns:
            torch.Tensor: The returned bce loss with logits.

        Raises:
            AssertionError: If ``class_weight`` is set and ``kwargs`` already
                contains ``'weight'``.
        """
        if self.class_weight is not None:
            assert 'weight' not in kwargs, "The key 'weight' already exists."
            # Passed as the element-wise ``weight``; a 1-D per-class vector
            # broadcasts over the last dim of ``cls_score`` — presumably
            # (N, num_classes); TODO confirm against callers.
            kwargs['weight'] = self.class_weight.to(cls_score.device)
        loss_cls = F.binary_cross_entropy_with_logits(cls_score, label,
                                                      **kwargs)
        return loss_cls