# Source code for mmaction.models.roi_extractors.single_straight3d

import torch
import torch.nn as nn

from mmaction.utils import import_module_error_class

# RoIAlign / RoIPool come from the compiled ops shipped with mmcv. When that
# import fails, placeholder classes with the same names are defined instead so
# that this module can still be imported.
try:
    from mmcv.ops import RoIAlign, RoIPool
except (ImportError, ModuleNotFoundError):

    # NOTE(review): `import_module_error_class` is defined in mmaction.utils;
    # presumably it makes the placeholder raise an error pointing at the
    # 'mmcv-full' package when it is used -- confirm against its definition.
    @import_module_error_class('mmcv-full')
    class RoIAlign(nn.Module):
        pass

    # Same placeholder treatment for RoIPool.
    @import_module_error_class('mmcv-full')
    class RoIPool(nn.Module):
        pass


# mmdet is an optional dependency: record in `mmdet_imported` whether its
# ROI_EXTRACTORS registry could be imported, so that code further down can
# check the flag before touching the registry.
try:
    from mmdet.models import ROI_EXTRACTORS
    mmdet_imported = True
except (ImportError, ModuleNotFoundError):
    mmdet_imported = False


class SingleRoIExtractor3D(nn.Module):
    """Extract RoI features from a single level feature map.

    Args:
        roi_layer_type (str): Specify the RoI layer type.
            Default: 'RoIAlign'.
        featmap_stride (int): Strides of input feature maps. Default: 16.
        output_size (int | tuple): Size or (Height, Width). Default: 16.
        sampling_ratio (int): number of inputs samples to take for each
            output sample. 0 to take samples densely for current models.
            Default: 0.
        pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
            Default: 'avg'.
        aligned (bool): if False, use the legacy implementation in
            MMDetection. If True, align the results more perfectly.
            Default: True.
        with_temporal_pool (bool): if True, avgpool the temporal dim.
            Default: True.
        with_global (bool): if True, concatenate the RoI feature with global
            feature. Default: False.

    Note that sampling_ratio, pool_mode, aligned only apply when
    roi_layer_type is set as RoIAlign.
    """

    def __init__(self,
                 roi_layer_type='RoIAlign',
                 featmap_stride=16,
                 output_size=16,
                 sampling_ratio=0,
                 pool_mode='avg',
                 aligned=True,
                 with_temporal_pool=True,
                 with_global=False):
        super().__init__()
        self.roi_layer_type = roi_layer_type
        assert self.roi_layer_type in ['RoIPool', 'RoIAlign'], (
            "roi_layer_type must be 'RoIPool' or 'RoIAlign', "
            f'but got {self.roi_layer_type!r}')
        self.featmap_stride = featmap_stride
        # RoI coordinates are scaled by this factor to map them onto the
        # feature map.
        self.spatial_scale = 1. / self.featmap_stride

        self.output_size = output_size
        self.sampling_ratio = sampling_ratio
        self.pool_mode = pool_mode
        self.aligned = aligned

        self.with_temporal_pool = with_temporal_pool
        self.with_global = with_global

        if self.roi_layer_type == 'RoIPool':
            self.roi_layer = RoIPool(self.output_size, self.spatial_scale)
        else:
            self.roi_layer = RoIAlign(
                self.output_size,
                self.spatial_scale,
                sampling_ratio=self.sampling_ratio,
                pool_mode=self.pool_mode,
                aligned=self.aligned)
        # Pools a whole frame to the RoI output size; only used by forward()
        # when with_global is True.
        self.global_pool = nn.AdaptiveAvgPool2d(self.output_size)

    def init_weights(self):
        """Do nothing; kept as an intentionally empty hook."""
        pass

    def forward(self, feat, rois):
        """Extract per-frame RoI features and stack them along time.

        Args:
            feat (torch.Tensor | tuple[torch.Tensor] | list[torch.Tensor]):
                Feature map(s), each of shape (N, C, T, H, W). Several
                feature maps are only supported together with
                ``with_temporal_pool=True``; they are concatenated along the
                channel dimension after temporal pooling.
            rois (torch.Tensor): 2D tensor with one RoI per row. Column 0 is
                used as the index of the sample in the batch; the remaining
                columns are passed through to the RoI layer unchanged
                (presumably x1, y1, x2, y2 as expected by the mmcv RoI ops).

        Returns:
            torch.Tensor: The per-frame outputs of the RoI layer stacked on
            a new dimension 2. When ``with_global`` is True, the globally
            pooled frame feature of the matching sample is concatenated to
            each RoI feature along the channel dimension.

        Raises:
            AssertionError: If more than one feature map is given while
                ``with_temporal_pool`` is False.
        """
        # Accept lists as well as tuples of pathway features; a bare tensor
        # is treated as a single pathway.
        if not isinstance(feat, (tuple, list)):
            feat = (feat, )

        if len(feat) >= 2:
            # Pathways may differ in temporal length, so they can only be
            # concatenated once the temporal dim has been pooled to 1.
            assert self.with_temporal_pool, (
                'multiple feature maps require with_temporal_pool=True')
        if self.with_temporal_pool:
            feat = [torch.mean(x, 2, keepdim=True) for x in feat]
        feat = torch.cat(feat, dim=1)

        # The batch index of every RoI does not depend on the frame, so it
        # is computed once instead of once per time step.
        batch_inds = None
        if self.with_global:
            batch_inds = rois[:, 0].type(torch.int64)

        roi_feats = []
        for t in range(feat.size(2)):
            # Slicing a time step yields a strided view; the RoI layer is
            # given a contiguous tensor.
            frame_feat = feat[:, :, t].contiguous()
            roi_feat = self.roi_layer(frame_feat, rois)
            if self.with_global:
                # Pick, for each RoI, the global feature of its own sample.
                global_feat = self.global_pool(frame_feat)[batch_inds]
                roi_feat = torch.cat([roi_feat, global_feat], dim=1)
                roi_feat = roi_feat.contiguous()
            roi_feats.append(roi_feat)

        return torch.stack(roi_feats, dim=2)
# Make the extractor available through mmdet's ROI_EXTRACTORS registry, but
# only when mmdet itself could be imported.
if mmdet_imported:
    ROI_EXTRACTORS.register_module()(SingleRoIExtractor3D)