import torch
from torch import nn
from ..registry import RECOGNIZERS
from .base import BaseRecognizer
@RECOGNIZERS.register_module()
class Recognizer2D(BaseRecognizer):
    """2D recognizer model framework.

    The first two dims of the input (batch and segments) are flattened
    together before ``extract_feat``. Per-frame features are then optionally
    regrouped per clip and aggregated by a neck, and finally scored by
    ``self.cls_head``.
    """

    def _pool_torchvision_feat(self, x):
        """Normalize torchvision backbone output to ``[N, C, 1, 1]``.

        Only applies when ``self.backbone_from == 'torchvision'``; otherwise
        ``x`` is returned unchanged.

        Args:
            x (torch.Tensor): Output of ``self.extract_feat``.

        Returns:
            torch.Tensor: ``x`` flattened per sample and given two trailing
                singleton dims (torchvision case), or ``x`` itself.
        """
        if self.backbone_from != 'torchvision':
            return x
        if x.dim() == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
            # Collapse a non-trivial spatial map with adaptive avg pooling.
            x = nn.functional.adaptive_avg_pool2d(x, 1)
        x = x.reshape((x.shape[0], -1))
        # Re-add 1x1 spatial dims; presumably downstream necks/heads expect
        # 4D ``[N, C, H, W]`` input -- TODO confirm against the heads used.
        return x.reshape(x.shape + (1, 1))

    @staticmethod
    def _reshape_for_neck(x, num_segs):
        """Regroup flattened per-frame features into per-clip tensors.

        Each feature ``[N * num_segs, C, ...]`` is reshaped to
        ``[N, num_segs, C, ...]`` and transposed to ``[N, C, num_segs, ...]``.

        Args:
            x (Iterable[torch.Tensor]): Features to regroup (iterated
                element-wise, as the neck input is a sequence of features).
            num_segs (int): Number of segments per clip.

        Returns:
            list[torch.Tensor]: The regrouped, contiguous features.
        """
        return [
            each.reshape((-1, num_segs) +
                         each.shape[1:]).transpose(1, 2).contiguous()
            for each in x
        ]

    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training.

        Args:
            imgs (torch.Tensor): Input clips. Dims 0 and 1 are flattened
                together before the backbone; dim 0 is the batch size.
            labels (torch.Tensor): Ground-truth labels. They are squeezed
                once and passed to both the neck and the head loss.
            **kwargs: Forwarded to ``self.cls_head.loss``.

        Returns:
            dict: Auxiliary neck losses (if a neck is used) merged with the
                classification losses from ``self.cls_head.loss``.
        """
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches

        losses = dict()
        gt_labels = labels.squeeze()

        x = self.extract_feat(imgs)
        x = self._pool_torchvision_feat(x)

        if self.with_neck:
            x = self._reshape_for_neck(x, num_segs)
            x, loss_aux = self.neck(x, gt_labels)
            x = x.squeeze(2)
            # The neck output is treated as a single segment per clip.
            num_segs = 1
            losses.update(loss_aux)

        cls_score = self.cls_head(x, num_segs)
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)
        return losses

    def _do_test(self, imgs):
        """Defines the computation performed at every call when evaluation,
        testing and gradcam.

        Args:
            imgs (torch.Tensor): Input clips. Dims 0 and 1 are flattened
                together before the backbone; dim 0 is the batch size.

        Returns:
            torch.Tensor: Head scores passed through ``self.average_clip``
                with ``num_crops = cls_score.size(0) // batches``.
        """
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches

        x = self.extract_feat(imgs)
        x = self._pool_torchvision_feat(x)

        if self.with_neck:
            x = self._reshape_for_neck(x, num_segs)
            x, _ = self.neck(x)
            x = x.squeeze(2)
            # The neck output is treated as a single segment per clip.
            num_segs = 1

        # When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes]
        # When using `TSMHead`, shape is [batch_size * num_crops, num_classes]
        # `num_crops` is calculated by:
        #   1) `twice_sample` in `SampleFrames`
        #   2) `num_sample_positions` in `DenseSampleFrames`
        #   3) `ThreeCrop/TenCrop/MultiGroupCrop` in `test_pipeline`
        #   4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
        cls_score = self.cls_head(x, num_segs)
        assert cls_score.size()[0] % batches == 0
        # calculate num_crops automatically
        return self.average_clip(cls_score, cls_score.size()[0] // batches)

    def _do_fcn_test(self, imgs):
        """Spatially fully-convolutional testing.

        Reads from ``self.test_cfg``:
            - ``num_segs``: segments per clip; defaults to
              ``self.backbone.num_segments``.
            - ``flip``: if truthy, inputs are flipped along their last dim.

        Args:
            imgs (torch.Tensor): Input clips. Dims 0 and 1 are flattened
                together before the backbone; dim 0 is the batch size.

        Returns:
            torch.Tensor: Head scores (``fcn_test=True``) passed through
                ``self.average_clip``.
        """
        # [N, num_crops * num_segs, C, H, W] ->
        # [N * num_crops * num_segs, C, H, W]
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = self.test_cfg.get('num_segs', self.backbone.num_segments)

        if self.test_cfg.get('flip', False):
            imgs = torch.flip(imgs, [-1])

        x = self.extract_feat(imgs)

        if self.with_neck:
            x = self._reshape_for_neck(x, num_segs)
            x, _ = self.neck(x)
        else:
            x = x.reshape((-1, num_segs) +
                          x.shape[1:]).transpose(1, 2).contiguous()

        # When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes]
        # When using `TSMHead`, shape is [batch_size * num_crops, num_classes]
        # `num_crops` is calculated by:
        #   1) `twice_sample` in `SampleFrames`
        #   2) `num_sample_positions` in `DenseSampleFrames`
        #   3) `ThreeCrop/TenCrop/MultiGroupCrop` in `test_pipeline`
        #   4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
        cls_score = self.cls_head(x, fcn_test=True)
        assert cls_score.size()[0] % batches == 0
        # calculate num_crops automatically
        return self.average_clip(cls_score, cls_score.size()[0] // batches)

    def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluation and
        testing.

        Uses ``_do_fcn_test`` when ``self.test_cfg['fcn_test']`` is truthy,
        otherwise ``_do_test``.

        Returns:
            numpy.ndarray: The resulting scores moved to CPU.
        """
        if self.test_cfg.get('fcn_test', False):
            # If specified, spatially fully-convolutional testing is performed
            return self._do_fcn_test(imgs).cpu().numpy()
        return self._do_test(imgs).cpu().numpy()

    def forward_dummy(self, imgs, softmax=False):
        """Used for computing network FLOPs.

        See ``tools/analysis/get_flops.py``.

        Args:
            imgs (torch.Tensor): Input images.
            softmax (bool): Whether to apply softmax over dim 1 of the
                class scores. Default: False.

        Returns:
            tuple[torch.Tensor]: One-element tuple holding the class score.
        """
        batches = imgs.shape[0]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        num_segs = imgs.shape[0] // batches

        x = self.extract_feat(imgs)
        # Match train/test preprocessing so FLOPs reflect the real pipeline.
        x = self._pool_torchvision_feat(x)

        if self.with_neck:
            x = self._reshape_for_neck(x, num_segs)
            x, _ = self.neck(x)
            x = x.squeeze(2)
            num_segs = 1

        outs = self.cls_head(x, num_segs)
        if softmax:
            # Explicit dim: implicit-dim softmax is deprecated in PyTorch
            # (for 2D scores the implicit choice was also dim 1).
            outs = nn.functional.softmax(outs, dim=1)
        return (outs, )

    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using gradcam
        utils.

        Returns:
            torch.Tensor: The output of ``_do_test`` (kept on device, not
                converted to numpy, so gradients remain available).
        """
        return self._do_test(imgs)