# Source code for mmaction.models.heads.base

from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn

from ...core import top_k_accuracy
from ..builder import build_loss


class AvgConsensus(nn.Module):
    """Consensus module that averages its input along one dimension.

    Args:
        dim (int): Dimension along which the mean is taken.
            Default: 1.
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the mean of ``x`` over ``self.dim``.

        The reduced dimension is kept with size 1, so the output has the
        same number of dimensions as the input.
        """
        averaged = torch.mean(x, dim=self.dim, keepdim=True)
        return averaged


class BaseHead(nn.Module, metaclass=ABCMeta):
    """Base class for head.

    All Head should subclass it.
    All subclass should overwrite:
    - Methods:``init_weights``, initializing weights in some modules.
    - Methods:``forward``, supporting to forward both for training and testing.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        loss_cls (dict): Config for building loss.
            Default: dict(type='CrossEntropyLoss', loss_weight=1.0).
        multi_class (bool): Determines whether it is a multi-class
            recognition task. Default: False.
        label_smooth_eps (float): Epsilon used in label smooth.
            Reference: arxiv.org/abs/1906.02629. Default: 0.
        topk (int | tuple[int]): Values of k for which ``loss`` reports a
            ``top{k}_acc`` entry. Default: (1, 5).

    Raises:
        TypeError: If ``topk`` is neither an int nor a tuple/list of ints.
        ValueError: If any value in ``topk`` is not positive.
    """

    # NOTE(review): the ``loss_cls`` default is a mutable dict shared across
    # calls. It is kept for interface compatibility; this is only safe if
    # ``build_loss`` does not mutate its argument -- confirm.
    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_cls=dict(type='CrossEntropyLoss', loss_weight=1.0),
                 multi_class=False,
                 label_smooth_eps=0.0,
                 topk=(1, 5)):
        super().__init__()
        # Validate first so that a bad ``topk`` fails before the loss is built.
        self.topk = self._normalize_topk(topk)
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.loss_cls = build_loss(loss_cls)
        self.multi_class = multi_class
        self.label_smooth_eps = label_smooth_eps

    @staticmethod
    def _normalize_topk(topk):
        """Return ``topk`` as a tuple of positive ints, or raise."""
        if isinstance(topk, int):
            topk = (topk, )
        elif isinstance(topk, (tuple, list)):
            topk = tuple(topk)
        else:
            raise TypeError(
                'topk should be an int or a tuple of ints, '
                f'but got {type(topk).__name__}')
        for k in topk:
            if not isinstance(k, int):
                raise TypeError(
                    'Every value in topk should be an int, '
                    f'but got {type(k).__name__}')
            if k <= 0:
                raise ValueError('Top-k should be larger than 0')
        return topk

    @abstractmethod
    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""

    @abstractmethod
    def forward(self, x):
        """Defines the computation performed at every call."""

    def loss(self, cls_score, labels, **kwargs):
        """Calculate the loss given output ``cls_score``, target ``labels``.

        Top-k accuracy is reported only when ``multi_class`` is False and
        ``labels`` does not have the same size as ``cls_score``. Label
        smoothing is applied only when ``multi_class`` is True and
        ``label_smooth_eps`` is non-zero.

        Args:
            cls_score (torch.Tensor): The output of the model.
            labels (torch.Tensor): The target output of the model.
            **kwargs: Extra keyword arguments forwarded to ``self.loss_cls``.

        Returns:
            dict: A dict containing field 'loss_cls'(mandatory) -- or the
            fields of the dict returned by ``self.loss_cls`` -- and one
            'top{k}_acc' entry per k in ``self.topk`` (optional; 'top1_acc'
            and 'top5_acc' with the default ``topk``).
        """
        losses = {}

        if labels.dim() == 0:
            # A 0-dim label is promoted to a 1-element batch.
            labels = labels.unsqueeze(0)
        elif (labels.dim() == 1 and labels.size(0) == self.num_classes
              and cls_score.size(0) == 1):
            # Fix a bug when training with soft labels and batch size is 1.
            # When using soft labels, `labels` and `cls_score` share the same
            # shape.
            labels = labels.unsqueeze(0)

        if not self.multi_class and cls_score.size() != labels.size():
            # Sizes differ, so `labels` are not soft labels and accuracy
            # against them can be computed.
            top_k_acc = top_k_accuracy(cls_score.detach().cpu().numpy(),
                                       labels.detach().cpu().numpy(),
                                       self.topk)
            # Assumes `top_k_accuracy` yields one value per requested k, in
            # the same order as `self.topk` (the original code indexed the
            # result positionally for (1, 5)).
            for k, acc in zip(self.topk, top_k_acc):
                losses[f'top{k}_acc'] = torch.tensor(
                    acc, device=cls_score.device)
        elif self.multi_class and self.label_smooth_eps != 0:
            labels = ((1 - self.label_smooth_eps) * labels +
                      self.label_smooth_eps / self.num_classes)

        loss_cls = self.loss_cls(cls_score, labels, **kwargs)
        # loss_cls may be dictionary or single tensor
        if isinstance(loss_cls, dict):
            losses.update(loss_cls)
        else:
            losses['loss_cls'] = loss_cls

        return losses