"""Source code for ``mmaction.models.backbones.tanet`` (TANet backbone)."""

from copy import deepcopy

import torch.nn as nn
from torch.utils import checkpoint as cp

from ..common import TAM
from ..registry import BACKBONES
from .resnet import Bottleneck, ResNet


class TABlock(nn.Module):
    """Temporal Adaptive Block (TA-Block) for TANet.

    This block is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    The temporal adaptive module (TAM) is embedded into ResNet-Block
    after the first Conv2D, which turns the vanilla ResNet-Block
    into TA-Block.

    Args:
        block (nn.Module): Residual block to be wrapped. Only
            :class:`Bottleneck` blocks are supported.
        num_segments (int): Number of frame segments.
        tam_cfg (dict | None): Config for temporal adaptive module (TAM),
            forwarded as keyword arguments to ``TAM``. ``None`` is treated
            as an empty config. Default: None.

    Raises:
        NotImplementedError: If ``block`` is not a :class:`Bottleneck`.
    """

    def __init__(self, block, num_segments, tam_cfg=None):
        super().__init__()
        # Validate before reading ``block.conv1`` so that unsupported blocks
        # fail with the documented error rather than an AttributeError.
        if not isinstance(block, Bottleneck):
            raise NotImplementedError('TA-Blocks have not been fully '
                                      'implemented except the pattern based '
                                      'on Bottleneck block.')

        # Private copy: later mutation of the caller's dict cannot leak in.
        self.tam_cfg = deepcopy(tam_cfg) if tam_cfg is not None else dict()
        self.block = block
        self.num_segments = num_segments
        self.tam = TAM(
            in_channels=block.conv1.out_channels,
            num_segments=num_segments,
            **self.tam_cfg)

    def forward(self, x):
        """Run the wrapped Bottleneck with TAM inserted after ``conv1``.

        ``__init__`` guarantees that ``self.block`` is a Bottleneck.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Residual output after the block's final ReLU.
        """

        def _inner_forward(x):
            """Forward wrapper for utilizing checkpoint."""
            identity = x

            out = self.block.conv1(x)
            out = self.tam(out)
            out = self.block.conv2(out)
            out = self.block.conv3(out)

            if self.block.downsample is not None:
                identity = self.block.downsample(x)

            return out + identity

        # Checkpoint only when the block requests it and gradients are
        # needed through ``x``.
        if self.block.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        # The block's ReLU is applied after the residual sum.
        return self.block.relu(out)


@BACKBONES.register_module()
class TANet(ResNet):
    """Temporal Adaptive Network (TANet) backbone.

    This backbone is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    Embedding the temporal adaptive module (TAM) into ResNet to
    instantiate TANet.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. Only
            depths whose stages are built from ``Bottleneck`` blocks can be
            converted, because ``TABlock`` rejects other block types.
        num_segments (int): Number of frame segments. Must be >= 3.
        tam_cfg (dict | None): Config for temporal adaptive module (TAM).
            ``None`` is treated as an empty config. Default: None.
        **kwargs (keyword arguments, optional): Arguments for ResNet except
            ``depth``.
    """

    def __init__(self, depth, num_segments, tam_cfg=None, **kwargs):
        super().__init__(depth, **kwargs)
        assert num_segments >= 3, (
            f'num_segments must be >= 3, but got {num_segments}')
        self.num_segments = num_segments
        # Private copy; ``None`` becomes an empty config so that it can be
        # unpacked as keyword arguments for TAM.
        self.tam_cfg = deepcopy(tam_cfg) if tam_cfg is not None else dict()

    def init_weights(self):
        """Initialize ResNet weights, then convert blocks to TA-Blocks.

        NOTE(review): the conversion runs after ``super().init_weights()``,
        presumably so the parent initialization (e.g. pretrained loading)
        sees the plain ResNet module layout. Confirm before reordering.
        """
        super().init_weights()
        self.make_tam_modeling()

    def make_tam_modeling(self):
        """Replace ResNet-Block with TA-Block.

        Every block in ``layer1`` .. ``layer{num_stages}`` is wrapped in a
        :class:`TABlock`. Blocks that are already TA-Blocks are kept as is,
        so repeated calls (e.g. ``init_weights`` called twice) do not try
        to wrap a ``TABlock`` inside another one.
        """

        def make_tam_block(stage, num_segments, tam_cfg):
            blocks = [
                block if isinstance(block, TABlock) else TABlock(
                    block, num_segments, deepcopy(tam_cfg))
                for block in stage.children()
            ]
            return nn.Sequential(*blocks)

        for i in range(self.num_stages):
            layer_name = f'layer{i + 1}'
            res_layer = getattr(self, layer_name)
            setattr(self, layer_name,
                    make_tam_block(res_layer, self.num_segments,
                                   self.tam_cfg))