import torch
from torch import nn
from ..registry import RECOGNIZERS
from .base import BaseRecognizer
@RECOGNIZERS.register_module()
class Recognizer3D(BaseRecognizer):
    """3D recognizer model framework.

    Every entry point merges the first two dimensions of ``imgs`` into one
    leading dimension before feature extraction (presumably
    ``(batch, num_views, ...) -> (batch * num_views, ...)`` — confirm against
    the data pipeline). ``extract_feat``, ``neck``, ``cls_head``,
    ``with_neck``, ``max_testing_views`` and ``average_clip`` are expected to
    be provided by ``BaseRecognizer``.
    """

    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training.

        Args:
            imgs (torch.Tensor): Input clips; the first two dimensions are
                merged before feature extraction.
            labels (torch.Tensor): Ground-truth labels. All size-1 dimensions
                are squeezed out before they reach the neck and the head.
            **kwargs: Forwarded unchanged to ``self.cls_head.loss``.

        Returns:
            dict: The neck's auxiliary losses (when a neck is configured)
            merged with the losses returned by ``self.cls_head.loss``.
        """
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        # Squeeze once and share the result between the neck and the head.
        # NOTE(review): with a batch of one this can collapse to a 0-d tensor;
        # the head's loss is presumably expected to handle that — confirm.
        gt_labels = labels.squeeze()
        losses = dict()

        x = self.extract_feat(imgs)
        if self.with_neck:
            x, loss_aux = self.neck(x, gt_labels)
            losses.update(loss_aux)

        cls_score = self.cls_head(x)
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)
        return losses

    def _forward_scores(self, imgs):
        """Run backbone, optional neck and head on already-flattened ``imgs``.

        The neck is called without labels and its second return value is
        discarded, as in every inference path of this class.
        """
        x = self.extract_feat(imgs)
        if self.with_neck:
            x, _ = self.neck(x)
        return self.cls_head(x)

    def _do_test(self, imgs):
        """Defines the computation performed at every call during evaluation,
        testing and Grad-CAM.

        Args:
            imgs (torch.Tensor): Input clips. ``imgs.shape[1]`` is passed to
                ``average_clip`` as the number of segments to average over.

        Returns:
            torch.Tensor: The output of ``self.average_clip``.

        Raises:
            AssertionError: If ``max_testing_views`` is set and the batch
                holds more than one sample.
        """
        num_segs = imgs.shape[1]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])

        if self.max_testing_views is not None:
            total_views = imgs.shape[0]
            # After flattening, total_views == batch * num_segs, so this
            # equality only holds for a batch of one.
            assert num_segs == total_views, (
                'max_testing_views is only compatible '
                'with batch_size == 1')
            # Score at most ``max_testing_views`` views per forward pass;
            # ``split`` yields the same consecutive chunks as manual slicing.
            cls_scores = [
                self._forward_scores(chunk)
                for chunk in imgs.split(self.max_testing_views)
            ]
            cls_score = torch.cat(cls_scores)
        else:
            cls_score = self._forward_scores(imgs)

        return self.average_clip(cls_score, num_segs)

    def forward_test(self, imgs):
        """Defines the computation performed at every call during evaluation
        and testing.

        Returns:
            numpy.ndarray: The clip-averaged scores from ``_do_test``, moved to
            the CPU.
        """
        # detach() so that .numpy() also works when autograd is enabled;
        # it does not change the values.
        return self._do_test(imgs).detach().cpu().numpy()

    def forward_dummy(self, imgs, softmax=False):
        """Used for computing network FLOPs.

        See ``tools/analysis/get_flops.py``.

        Args:
            imgs (torch.Tensor): Input images.
            softmax (bool): Whether to apply softmax to the class scores.
                Default: False.

        Returns:
            tuple[torch.Tensor]: A one-element tuple holding the class scores.
        """
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        outs = self._forward_scores(imgs)
        if softmax:
            # Implicit-dim softmax is deprecated in PyTorch. dim=1 matches its
            # implicit choice for 2-D input; assumes scores are
            # (N, num_classes) — confirm for custom heads.
            outs = nn.functional.softmax(outs, dim=1)
        return (outs, )

    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using Grad-CAM
        utils.

        Unlike ``forward_test``, this returns the tensor itself (not detached,
        not converted to NumPy), presumably so gradients can flow back through
        it for Grad-CAM.
        """
        return self._do_test(imgs)