# Source code for mmaction.models.roi_extractors.single_straight3d
import torch
import torch.nn as nn
from mmaction.utils import import_module_error_class
# The RoI operators are imported from mmcv.ops. If that import fails, define
# same-named placeholder classes wrapped by
# ``import_module_error_class('mmcv-full')`` so that this module can still be
# imported (presumably the wrapper reports the missing package on use --
# see mmaction.utils for its exact behavior).
try:
    from mmcv.ops import RoIAlign, RoIPool
except ImportError:  # ModuleNotFoundError is a subclass of ImportError

    @import_module_error_class('mmcv-full')
    class RoIAlign(nn.Module):
        pass

    @import_module_error_class('mmcv-full')
    class RoIPool(nn.Module):
        pass
# MMDetection is an optional dependency: remember whether its RoI-extractor
# registry could be imported so registration can be skipped when it is absent.
try:
    from mmdet.models import ROI_EXTRACTORS
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    mmdet_imported = False
else:
    mmdet_imported = True
class SingleRoIExtractor3D(nn.Module):
    """Extract RoI features from a single level feature map.

    Args:
        roi_layer_type (str): Specify the RoI layer type, either 'RoIPool'
            or 'RoIAlign'. Default: 'RoIAlign'.
        featmap_stride (int): Strides of input feature maps. The spatial
            scale handed to the RoI layer is ``1 / featmap_stride``.
            Default: 16.
        output_size (int | tuple): Size or (Height, Width). Default: 16.
        sampling_ratio (int): number of inputs samples to take for each
            output sample. 0 to take samples densely for current models.
            Default: 0.
        pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
            Default: 'avg'.
        aligned (bool): if False, use the legacy implementation in
            MMDetection. If True, align the results more perfectly.
            Default: True.
        with_temporal_pool (bool): if True, avgpool the temporal dim.
            Default: True.
        with_global (bool): if True, concatenate the RoI feature with global
            feature. Default: False.

    Note that sampling_ratio, pool_mode, aligned only apply when
    roi_layer_type is set as RoIAlign.
    """

    def __init__(self,
                 roi_layer_type='RoIAlign',
                 featmap_stride=16,
                 output_size=16,
                 sampling_ratio=0,
                 pool_mode='avg',
                 aligned=True,
                 with_temporal_pool=True,
                 with_global=False):
        super().__init__()
        self.roi_layer_type = roi_layer_type
        assert self.roi_layer_type in ['RoIPool', 'RoIAlign'], (
            "roi_layer_type must be 'RoIPool' or 'RoIAlign', "
            f'but got {self.roi_layer_type!r}')

        self.featmap_stride = featmap_stride
        # RoI coordinates are scaled by this factor inside the RoI layer.
        self.spatial_scale = 1. / self.featmap_stride

        self.output_size = output_size
        self.sampling_ratio = sampling_ratio
        self.pool_mode = pool_mode
        self.aligned = aligned

        self.with_temporal_pool = with_temporal_pool
        self.with_global = with_global

        if self.roi_layer_type == 'RoIPool':
            self.roi_layer = RoIPool(self.output_size, self.spatial_scale)
        else:
            self.roi_layer = RoIAlign(
                self.output_size,
                self.spatial_scale,
                sampling_ratio=self.sampling_ratio,
                pool_mode=self.pool_mode,
                aligned=self.aligned)
        # Pools a whole frame to ``output_size``; only used by ``forward``
        # when ``with_global`` is True.
        self.global_pool = nn.AdaptiveAvgPool2d(self.output_size)

    def init_weights(self):
        """Do nothing; kept so the module offers an ``init_weights`` hook."""
        pass

    # The shape of feat is N, C, T, H, W
    def forward(self, feat, rois):
        """Extract per-frame RoI features and stack them along time.

        Args:
            feat (torch.Tensor | tuple | list): A single feature map, or a
                sequence of feature maps that are concatenated along the
                channel dim (dim 1). Each map is laid out as N, C, T, H, W.
                Passing two or more maps requires ``with_temporal_pool``.
            rois (torch.Tensor): RoIs forwarded unchanged to the RoI layer.
                Column 0 is read as the batch index of every RoI when
                ``with_global`` is True. (The remaining columns are
                presumably box coordinates -- confirm against the RoI
                layer's documentation.)

        Returns:
            torch.Tensor: RoI features of every frame, stacked on dim 2.
            With ``with_global`` the pooled whole-frame feature of the
            RoI's sample is appended on the channel dim.
        """
        # Accept a bare tensor as well as a tuple or list of feature maps.
        if not isinstance(feat, (tuple, list)):
            feat = (feat, )

        if len(feat) >= 2:
            assert self.with_temporal_pool, (
                'with_temporal_pool must be True when more than one '
                'feature map is given')
        if self.with_temporal_pool:
            # Average over T but keep the dim so the frame loop still works.
            feat = [torch.mean(x, 2, keepdim=True) for x in feat]
        feat = torch.cat(list(feat), dim=1)

        if self.with_global:
            # Loop-invariant: batch index of each RoI.
            inds = rois[:, 0].type(torch.int64)

        roi_feats = []
        for t in range(feat.size(2)):
            frame_feat = feat[:, :, t].contiguous()
            roi_feat = self.roi_layer(frame_feat, rois)
            if self.with_global:
                # Pick, for each RoI, the pooled feature of its own sample.
                global_feat = self.global_pool(frame_feat)[inds]
                roi_feat = torch.cat([roi_feat, global_feat], dim=1)
                roi_feat = roi_feat.contiguous()
            roi_feats.append(roi_feat)

        return torch.stack(roi_feats, dim=2)
# Make the extractor available through MMDetection's ROI_EXTRACTORS registry,
# but only when MMDetection could be imported.
if mmdet_imported:
    _register = ROI_EXTRACTORS.register_module()
    _register(SingleRoIExtractor3D)