import copy
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils import _BatchNorm
from mmaction.models.common import LFB
from mmaction.utils import get_root_logger
try:
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
mmdet_imported = False
class NonLocalLayer(nn.Module):
    """Non-local layer used in `FBONonLocal` is a variation of the vanilla non-
    local block.

    Short-term features act as queries (``theta``) and long-term features as
    keys (``phi``) and values (``g``). The attended values are projected back
    to ``st_feat_channels`` by ``out_conv``.

    Args:
        st_feat_channels (int): Channels of short-term features.
        lt_feat_channels (int): Channels of long-term features.
        latent_channels (int): Channels of latent features.
        num_st_feat (int): Number of short-term roi features.
        num_lt_feat (int): Number of long-term roi features.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(latent_channels)`. Default: True.
        pre_activate (bool): Whether to apply the activation (optionally
            preceded by LayerNorm) before `out_conv`. If False, LayerNorm is
            applied after `out_conv` and no activation is applied inside this
            layer. Default: True.
        pre_activate_with_ln (bool): Whether to apply LayerNorm before the
            ReLU when ``pre_activate`` is True. Default: True.
        conv_cfg (Dict | None): The config dict for convolution layers. If
            not specified, it will use `nn.Conv3d` for convolution layers.
            Default: None.
        norm_cfg (Dict | None): The config dict for normalization layers.
            Default: None.
        dropout_ratio (float, optional): Probability of dropout layer.
            Default: 0.2.
        zero_init_out_conv (bool): Whether to use zero initialization for
            out_conv. Default: False.
    """

    def __init__(self,
                 st_feat_channels,
                 lt_feat_channels,
                 latent_channels,
                 num_st_feat,
                 num_lt_feat,
                 use_scale=True,
                 pre_activate=True,
                 pre_activate_with_ln=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 dropout_ratio=0.2,
                 zero_init_out_conv=False):
        super().__init__()
        if conv_cfg is None:
            conv_cfg = dict(type='Conv3d')
        self.st_feat_channels = st_feat_channels
        self.lt_feat_channels = lt_feat_channels
        self.latent_channels = latent_channels
        self.num_st_feat = num_st_feat
        self.num_lt_feat = num_lt_feat
        self.use_scale = use_scale
        self.pre_activate = pre_activate
        self.pre_activate_with_ln = pre_activate_with_ln
        self.dropout_ratio = dropout_ratio
        self.zero_init_out_conv = zero_init_out_conv

        # 1x1 projections: query (theta) from short-term features, key (phi)
        # and value (g) from long-term features.
        self.st_feat_conv = ConvModule(
            self.st_feat_channels,
            self.latent_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.lt_feat_conv = ConvModule(
            self.lt_feat_channels,
            self.latent_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        self.global_conv = ConvModule(
            self.lt_feat_channels,
            self.latent_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # With pre-activation the LayerNorm sees the latent tensor (before
        # out_conv); otherwise it sees the out_conv output, which has
        # st_feat_channels channels.
        if pre_activate:
            self.ln = nn.LayerNorm([latent_channels, num_st_feat, 1, 1])
        else:
            self.ln = nn.LayerNorm([st_feat_channels, num_st_feat, 1, 1])

        self.relu = nn.ReLU()

        self.out_conv = ConvModule(
            self.latent_channels,
            self.st_feat_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)

    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Args:
            pretrained (str | None): Checkpoint path to load (non-strict), or
                None to initialize from scratch.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {pretrained}')
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
            if self.zero_init_out_conv:
                # mmcv's ConvModule wraps the actual layer in ``.conv`` and
                # has no ``weight``/``bias`` of its own. ``constant_init``
                # only touches ``weight``/``bias`` attributes, so calling it
                # on the wrapper would be a silent no-op.
                constant_init(self.out_conv.conv, 0, bias=0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, st_feat, lt_feat):
        """Attend from short-term features to long-term features.

        Args:
            st_feat (torch.Tensor): Short-term features. After
                ``st_feat_conv`` they must be viewable as
                ``(n, latent_channels, num_st_feat)``.
            lt_feat (torch.Tensor): Long-term features. After the 1x1
                projections they must be viewable as
                ``(n, latent_channels, num_lt_feat)``.

        Returns:
            torch.Tensor: The ``out_conv`` output (after optional LayerNorm
            and dropout). With the default Conv3d config its shape is
            ``(n, st_feat_channels, num_st_feat, 1, 1)``.
        """
        n, c = st_feat.size(0), self.latent_channels
        num_st_feat, num_lt_feat = self.num_st_feat, self.num_lt_feat

        theta = self.st_feat_conv(st_feat)
        theta = theta.view(n, c, num_st_feat)

        phi = self.lt_feat_conv(lt_feat)
        phi = phi.view(n, c, num_lt_feat)

        g = self.global_conv(lt_feat)
        g = g.view(n, c, num_lt_feat)

        # (n, num_st_feat, c), (n, c, num_lt_feat)
        # -> (n, num_st_feat, num_lt_feat)
        theta_phi = torch.matmul(theta.permute(0, 2, 1), phi)
        if self.use_scale:
            theta_phi /= c**0.5

        # Attention weights over the long-term features.
        p = theta_phi.softmax(dim=-1)

        # (n, c, num_lt_feat), (n, num_lt_feat, num_st_feat)
        # -> (n, c, num_st_feat, 1, 1)
        out = torch.matmul(g, p.permute(0, 2, 1)).view(n, c, num_st_feat, 1, 1)

        # If need to activate it before out_conv, use relu here, otherwise
        # use relu outside the non local layer.
        if self.pre_activate:
            if self.pre_activate_with_ln:
                out = self.ln(out)
            out = self.relu(out)

        out = self.out_conv(out)

        if not self.pre_activate:
            out = self.ln(out)
        if self.dropout_ratio > 0:
            out = self.dropout(out)

        return out
class FBONonLocal(nn.Module):
    """Non local feature bank operator.

    Short-term and long-term features are first projected to
    ``latent_channels`` with 1x1x1 ``nn.Conv3d`` layers. They are then fused
    by a stack of ``NonLocalLayer`` modules, each wrapped in a residual
    connection.

    Args:
        st_feat_channels (int): Channels of short-term features.
        lt_feat_channels (int): Channels of long-term features.
        latent_channels (int): Channels of latent features.
        num_st_feat (int): Number of short-term roi features.
        num_lt_feat (int): Number of long-term roi features.
        num_non_local_layers (int): Number of non-local layers, which is
            at least 1. Default: 2.
        st_feat_dropout_ratio (float): Probability of dropout layer for
            short-term features. Default: 0.2.
        lt_feat_dropout_ratio (float): Probability of dropout layer for
            long-term features. Default: 0.2.
        pre_activate (bool): Whether to use the activation function before
            upsampling in non local layers. If False, a ReLU is applied
            after each residual sum instead. Default: True.
        zero_init_out_conv (bool): Whether to use zero initialization for
            out_conv in NonLocalLayer. Default: False.
    """

    def __init__(self,
                 st_feat_channels,
                 lt_feat_channels,
                 latent_channels,
                 num_st_feat,
                 num_lt_feat,
                 num_non_local_layers=2,
                 st_feat_dropout_ratio=0.2,
                 lt_feat_dropout_ratio=0.2,
                 pre_activate=True,
                 zero_init_out_conv=False):
        super().__init__()
        assert num_non_local_layers >= 1, (
            'At least one non_local_layer is needed.')
        self.st_feat_channels = st_feat_channels
        self.lt_feat_channels = lt_feat_channels
        self.latent_channels = latent_channels
        self.num_st_feat = num_st_feat
        self.num_lt_feat = num_lt_feat
        self.num_non_local_layers = num_non_local_layers
        self.st_feat_dropout_ratio = st_feat_dropout_ratio
        self.lt_feat_dropout_ratio = lt_feat_dropout_ratio
        self.pre_activate = pre_activate
        self.zero_init_out_conv = zero_init_out_conv

        # Project both feature streams into the shared latent space.
        self.st_feat_conv = nn.Conv3d(
            st_feat_channels, latent_channels, kernel_size=1)
        self.lt_feat_conv = nn.Conv3d(
            lt_feat_channels, latent_channels, kernel_size=1)

        if self.st_feat_dropout_ratio > 0:
            self.st_feat_dropout = nn.Dropout(self.st_feat_dropout_ratio)
        if self.lt_feat_dropout_ratio > 0:
            self.lt_feat_dropout = nn.Dropout(self.lt_feat_dropout_ratio)

        # Without pre-activation the ReLU is applied here, after each
        # residual sum, instead of inside the non-local layers.
        if not self.pre_activate:
            self.relu = nn.ReLU()

        # Layers are registered as named submodules; this list keeps their
        # attribute names in application order.
        self.non_local_layers = []
        for idx in range(self.num_non_local_layers):
            layer_name = f'non_local_layer_{idx + 1}'
            self.add_module(
                layer_name,
                NonLocalLayer(
                    latent_channels,
                    latent_channels,
                    latent_channels,
                    num_st_feat,
                    num_lt_feat,
                    pre_activate=self.pre_activate,
                    zero_init_out_conv=self.zero_init_out_conv))
            self.non_local_layers.append(layer_name)

    def init_weights(self, pretrained=None):
        """Initialize weights from a checkpoint or from scratch.

        Args:
            pretrained (str | None): Checkpoint path to load (non-strict), or
                None to Kaiming-initialize the input projections and let each
                non-local layer initialize itself.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {pretrained}')
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            kaiming_init(self.st_feat_conv)
            kaiming_init(self.lt_feat_conv)
            for layer_name in self.non_local_layers:
                getattr(self, layer_name).init_weights(pretrained=pretrained)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, st_feat, lt_feat):
        """Fuse short-term features with long-term features.

        Args:
            st_feat (torch.Tensor): Short-term features with
                ``st_feat_channels`` channels.
            lt_feat (torch.Tensor): Long-term features with
                ``lt_feat_channels`` channels.

        Returns:
            torch.Tensor: The short-term features in latent space after all
            residual non-local layers.
        """
        # prepare st_feat
        st_feat = self.st_feat_conv(st_feat)
        if self.st_feat_dropout_ratio > 0:
            st_feat = self.st_feat_dropout(st_feat)

        # prepare lt_feat
        lt_feat = self.lt_feat_conv(lt_feat)
        if self.lt_feat_dropout_ratio > 0:
            lt_feat = self.lt_feat_dropout(lt_feat)

        # Fuse short-term and long-term features in the NonLocal layers, each
        # wrapped in a residual connection.
        for layer_name in self.non_local_layers:
            non_local_layer = getattr(self, layer_name)
            st_feat = st_feat + non_local_layer(st_feat, lt_feat)
            if not self.pre_activate:
                st_feat = self.relu(st_feat)

        return st_feat
class FBOAvg(nn.Module):
    """Avg pool feature bank operator.

    Ignores the short-term feature. Collapses dimension 2 of the long-term
    feature with adaptive average pooling and keeps the last two dimensions
    unchanged.
    """

    def __init__(self):
        super().__init__()
        # Output size (1, None, None): reduce dim 2 to 1, keep dims 3 and 4.
        self.avg_pool = nn.AdaptiveAvgPool3d((1, None, None))

    def init_weights(self, pretrained=None):
        """No-op: this operator has no learnable parameters."""

    def forward(self, st_feat, lt_feat):
        """Return ``lt_feat`` average-pooled over dimension 2.

        ``st_feat`` is accepted only so all operators share one call
        signature; it is not used.
        """
        pooled_lt_feat = self.avg_pool(lt_feat)
        return pooled_lt_feat
class FBOMax(nn.Module):
    """Max pool feature bank operator.

    Ignores the short-term feature. Collapses dimension 2 of the long-term
    feature with adaptive max pooling and keeps the last two dimensions
    unchanged.
    """

    def __init__(self):
        super().__init__()
        # Output size (1, None, None): reduce dim 2 to 1, keep dims 3 and 4.
        self.max_pool = nn.AdaptiveMaxPool3d((1, None, None))

    def init_weights(self, pretrained=None):
        """No-op: this operator has no learnable parameters."""

    def forward(self, st_feat, lt_feat):
        """Return ``lt_feat`` max-pooled over dimension 2.

        ``st_feat`` is accepted only so all operators share one call
        signature; it is not used.
        """
        pooled_lt_feat = self.max_pool(lt_feat)
        return pooled_lt_feat
class FBOHead(nn.Module):
    """Feature Bank Operator Head.

    Add feature bank operator for the spatiotemporal detection model to fuse
    short-term features and long-term features.

    Args:
        lfb_cfg (Dict): The config dict for LFB which is used to sample
            long-term features.
        fbo_cfg (Dict): The config dict for feature bank operator (FBO). The
            type of fbo is given by its ``'type'`` key (default:
            ``'non_local'``), and the supported fbo types are the keys of
            `fbo_dict`. The dict passed in is not modified.
        temporal_pool_type (str): The temporal pool type. Choices are 'avg' or
            'max'. Default: 'avg'.
        spatial_pool_type (str): The spatial pool type. Choices are 'avg' or
            'max'. Default: 'max'.
    """

    fbo_dict = {'non_local': FBONonLocal, 'avg': FBOAvg, 'max': FBOMax}

    def __init__(self,
                 lfb_cfg,
                 fbo_cfg,
                 temporal_pool_type='avg',
                 spatial_pool_type='max'):
        super().__init__()
        # Work on private copies so the caller's configs are left untouched.
        # Popping 'type' from the caller's dict would make a reused config
        # silently fall back to 'non_local' the next time it is used.
        self.lfb_cfg = copy.deepcopy(lfb_cfg)
        self.fbo_cfg = copy.deepcopy(fbo_cfg)
        fbo_type = self.fbo_cfg.pop('type', 'non_local')
        assert fbo_type in FBOHead.fbo_dict, (
            f'Unsupported fbo type {fbo_type!r}, '
            f'expected one of {list(FBOHead.fbo_dict)}')
        assert temporal_pool_type in ['max', 'avg']
        assert spatial_pool_type in ['max', 'avg']

        self.lfb = LFB(**self.lfb_cfg)
        self.fbo = self.fbo_dict[fbo_type](**self.fbo_cfg)

        # Pool by default
        if temporal_pool_type == 'avg':
            self.temporal_pool = nn.AdaptiveAvgPool3d((1, None, None))
        else:
            self.temporal_pool = nn.AdaptiveMaxPool3d((1, None, None))
        if spatial_pool_type == 'avg':
            self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
        else:
            self.spatial_pool = nn.AdaptiveMaxPool3d((None, 1, 1))

    def init_weights(self, pretrained=None):
        """Initialize the weights in the module.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Default: None.
        """
        self.fbo.init_weights(pretrained=pretrained)

    def sample_lfb(self, rois, img_metas):
        """Sample long-term features for each ROI feature.

        Args:
            rois (torch.Tensor): RoIs whose column 0 is the index into
                ``img_metas`` of the image each RoI belongs to.
            img_metas (list[dict]): Per-image meta info. The value of
                ``'img_key'`` is used to query the long-term feature bank.

        Returns:
            torch.Tensor: One stacked long-term feature per RoI, with the
            last two dims permuted and two trailing singleton dims added.
            Presumably
            ``[N, lfb_channels, window_size * max_num_feat_per_step, 1, 1]``
            (depends on what ``LFB`` returns; verify).
        """
        # Convert the indices to Python ints once instead of iterating over
        # 0-d tensors (one host sync rather than one per RoI).
        inds = rois[:, 0].type(torch.int64).tolist()
        # NOTE(review): ``.to()`` with no arguments is a no-op for tensors;
        # kept in case LFB entries rely on it.
        lt_feat_list = [
            self.lfb[img_metas[ind]['img_key']].to() for ind in inds
        ]
        lt_feat = torch.stack(lt_feat_list, dim=0)
        # [N, lfb_channels, window_size * max_num_feat_per_step]
        lt_feat = lt_feat.permute(0, 2, 1).contiguous()
        return lt_feat.unsqueeze(-1).unsqueeze(-1)

    def forward(self, x, rois, img_metas):
        """Fuse pooled RoI features with sampled long-term features.

        Args:
            x (torch.Tensor): 5-D RoI features to be pooled temporally
                (dim 2) and spatially (dims 3 and 4).
            rois (torch.Tensor): RoIs, see :meth:`sample_lfb`.
            img_metas (list[dict]): Image metas, see :meth:`sample_lfb`.

        Returns:
            torch.Tensor: The pooled short-term features concatenated with
            the FBO output along the channel dimension (dim 1).
        """
        # [N, C, 1, 1, 1]
        st_feat = self.temporal_pool(x)
        st_feat = self.spatial_pool(st_feat)
        identity = st_feat

        # [N, C, window_size * num_feat_per_step, 1, 1]
        lt_feat = self.sample_lfb(rois, img_metas).to(st_feat.device)

        fbo_feat = self.fbo(st_feat, lt_feat)

        out = torch.cat([identity, fbo_feat], dim=1)
        return out
# Register FBOHead in mmdet's SHARED_HEADS registry only when the optional
# mmdet import at the top of this file succeeded (mmdet_imported is True).
if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(FBOHead)