import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import auto_fp16
from .. import builder
class BaseRecognizer(nn.Module, metaclass=ABCMeta):
    """Base class for recognizers.

    All recognizers should subclass it.
    All subclasses should overwrite:

    - Methods:``forward_train``, supporting to forward when training.
    - Methods:``forward_test``, supporting to forward when testing.
    - Methods:``forward_gradcam``, supporting to forward when using gradcam
      utils.

    Args:
        backbone (dict): Backbone modules to extract feature. A ``type``
            starting with ``'mmcls.'`` or ``'torchvision.'`` selects the
            corresponding library to build the backbone; otherwise the
            local ``builder`` is used. The dict is not modified.
        cls_head (dict): Classification head to process feature.
        neck (dict | None): Config of the neck. It is only built when it is
            not None. Default: None.
        train_cfg (dict | None): Config for training. Default: None.
        test_cfg (dict | None): Config for testing. Default: None.
    """

    def __init__(self,
                 backbone,
                 cls_head,
                 neck=None,
                 train_cfg=None,
                 test_cfg=None):
        super().__init__()
        # Work on a shallow copy: the 'type' entry is rewritten / popped
        # below and the caller's config must stay reusable.
        backbone = backbone.copy()

        # record the source of the backbone
        self.backbone_from = 'mmaction2'

        mmcls_prefix = 'mmcls.'
        torchvision_prefix = 'torchvision.'
        if backbone['type'].startswith(mmcls_prefix):
            try:
                import mmcls.models.builder as mmcls_builder
            except ImportError as err:
                raise ImportError(
                    'Please install mmcls to use this backbone.') from err
            backbone['type'] = backbone['type'][len(mmcls_prefix):]
            self.backbone = mmcls_builder.build_backbone(backbone)
            self.backbone_from = 'mmcls'
        elif backbone['type'].startswith(torchvision_prefix):
            try:
                import torchvision.models
            except ImportError as err:
                raise ImportError('Please install torchvision to use this '
                                  'backbone.') from err
            backbone_type = backbone.pop('type')[len(torchvision_prefix):]
            # the remaining entries are the constructor keyword arguments
            self.backbone = torchvision.models.__dict__[backbone_type](
                **backbone)
            # disable the classifier
            self.backbone.classifier = nn.Identity()
            self.backbone.fc = nn.Identity()
            self.backbone_from = 'torchvision'
        else:
            self.backbone = builder.build_backbone(backbone)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        self.cls_head = builder.build_head(cls_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # aux_info is the list of tensor names beyond 'imgs' and 'label' which
        # will be used in train_step and val_step, data_batch should contain
        # these tensors
        self.aux_info = []
        if train_cfg is not None and 'aux_info' in train_cfg:
            self.aux_info = train_cfg['aux_info']

        # max_testing_views should be int
        self.max_testing_views = None
        if test_cfg is not None and 'max_testing_views' in test_cfg:
            self.max_testing_views = test_cfg['max_testing_views']
            assert isinstance(self.max_testing_views, int)

        # mini-batch blending, e.g. mixup, cutmix, etc.
        self.blending = None
        if train_cfg is not None and 'blending' in train_cfg:
            from mmcv.utils import build_from_cfg
            from ...datasets.registry import BLENDINGS
            self.blending = build_from_cfg(train_cfg['blending'], BLENDINGS)

        self.init_weights()

        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the recognizer has a neck"""
        return hasattr(self, 'neck') and self.neck is not None

    def init_weights(self):
        """Initialize the model network weights.

        Raises:
            NotImplementedError: If ``self.backbone_from`` is not one of
                ``'mmcls'``, ``'mmaction2'`` or ``'torchvision'``.
        """
        if self.backbone_from in ['mmcls', 'mmaction2']:
            self.backbone.init_weights()
        elif self.backbone_from == 'torchvision':
            warnings.warn('We do not initialize weights for backbones in '
                          'torchvision, since the weights for backbones in '
                          'torchvision are initialized in their __init__ '
                          'functions. ')
        else:
            raise NotImplementedError('Unsupported backbone source '
                                      f'{self.backbone_from}!')
        self.cls_head.init_weights()
        if self.with_neck:
            self.neck.init_weights()

    def average_clip(self, cls_score, num_segs=1):
        """Averaging class score over multiple clips.

        Using different averaging types ('score' or 'prob' or None,
        which defined in test_cfg) to computed the final averaged
        class score. Only called in test mode.

        Args:
            cls_score (torch.Tensor): Class score to be averaged.
            num_segs (int): Number of clips for each input sample.

        Returns:
            torch.Tensor: Averaged class score. ``cls_score`` is returned
                unchanged when ``average_clips`` is None.

        Raises:
            KeyError: If ``test_cfg`` is None or has no ``average_clips``.
            ValueError: If ``average_clips`` is not 'score', 'prob' or None.
        """
        # A missing test_cfg is reported like a missing key instead of
        # failing with an AttributeError on ``None``.
        if self.test_cfg is None or 'average_clips' not in self.test_cfg:
            raise KeyError('"average_clips" must defined in test_cfg\'s keys')

        average_clips = self.test_cfg['average_clips']
        if average_clips not in ['score', 'prob', None]:
            raise ValueError(f'{average_clips} is not supported. '
                             f'Currently supported ones are '
                             f'["score", "prob", None]')

        if average_clips is None:
            return cls_score

        # regroup the leading dimension into (samples, clips per sample)
        batch_size = cls_score.shape[0]
        cls_score = cls_score.view(batch_size // num_segs, num_segs, -1)

        if average_clips == 'prob':
            # softmax over the last dimension, then mean over the clips
            cls_score = F.softmax(cls_score, dim=2).mean(dim=1)
        elif average_clips == 'score':
            cls_score = cls_score.mean(dim=1)

        return cls_score

    @abstractmethod
    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training."""

    @abstractmethod
    def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluation and
        testing."""

    @abstractmethod
    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using gradcam
        utils."""

    @staticmethod
    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.

        Raises:
            TypeError: If a value is neither a tensor nor a list of tensors.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        # only the entries whose name contains 'loss' make up the total loss
        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                # operate on a clone so the returned ``loss`` is untouched
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def forward(self, imgs, label=None, return_loss=True, **kwargs):
        """Define the computation performed at every call.

        Args:
            imgs: Input forwarded to the selected ``forward_*`` method.
            label: Labels, required when ``return_loss`` is True.
                Default: None.
            return_loss (bool): Whether to run ``forward_train`` (True) or
                ``forward_test`` (False). Default: True.
            **kwargs: Passed on to the selected ``forward_*`` method. A
                truthy ``gradcam`` entry selects ``forward_gradcam``.

        Raises:
            ValueError: If ``return_loss`` is True and ``label`` is None.
        """
        # Always consume the flag, so an explicit ``gradcam=False`` is not
        # passed on to forward_train / forward_test.
        if kwargs.pop('gradcam', False):
            return self.forward_gradcam(imgs, **kwargs)
        if return_loss:
            if label is None:
                raise ValueError('Label should not be None.')
            if self.blending is not None:
                imgs, label = self.blending(imgs, label)
            return self.forward_train(imgs, label, **kwargs)

        return self.forward_test(imgs, **kwargs)

    def _collect_aux_info(self, data_batch):
        """Pick the tensors named in ``self.aux_info`` out of ``data_batch``.

        Raises:
            KeyError: If ``data_batch`` lacks any name in ``self.aux_info``.
        """
        missing = [item for item in self.aux_info if item not in data_batch]
        if missing:
            raise KeyError(
                f'data_batch does not contain the aux_info items {missing}')
        return {item: data_batch[item] for item in self.aux_info}

    def _loss_step(self, data_batch):
        """Run one loss-returning forward pass and package its outputs."""
        imgs = data_batch['imgs']
        label = data_batch['label']
        aux_info = self._collect_aux_info(data_batch)

        losses = self(imgs, label, return_loss=True, **aux_info)
        loss, log_vars = self._parse_losses(losses)

        return dict(
            loss=loss,
            log_vars=log_vars,
            # the length of the first entry of data_batch
            num_samples=len(next(iter(data_batch.values()))))

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.

        Raises:
            KeyError: If ``data_batch`` lacks ``imgs``, ``label`` or any name
                listed in ``self.aux_info``.
        """
        return self._loss_step(data_batch)

    def val_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        return self._loss_step(data_batch)