# Source code for mmaction.models.heads.ssn_head

import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from ..registry import HEADS


def parse_stage_config(stage_cfg):
    """Parse the STPP config of a single stage.

    Args:
        stage_cfg (int | tuple[int] | list[int]):
            Config of structured temporal pyramid pooling for one stage.
            An int is wrapped into a one-element tuple. A tuple or list is
            taken as the per-level part counts. Lists are accepted in
            addition to tuples because configs that went through JSON/YAML
            serialization come back as lists.

    Returns:
        tuple[tuple[int], int]:
            Config of structured temporal pyramid pooling (always a tuple)
            and total number of parts (number of multipliers).

    Raises:
        ValueError: If ``stage_cfg`` is not an int, a tuple or a list.
    """
    if isinstance(stage_cfg, int):
        return (stage_cfg, ), stage_cfg
    if isinstance(stage_cfg, (tuple, list)):
        # Normalize to a tuple so callers always get the documented type.
        stage_parts = tuple(stage_cfg)
        return stage_parts, sum(stage_parts)
    raise ValueError(f'Incorrect STPP config {stage_cfg}')


class STPPTrain(nn.Module):
    """Training-time structured temporal pyramid pooling (STPP) for SSN.

    The segments of every sample are split into three consecutive stages
    (starting, course, ending). Each stage is average-pooled over the parts
    given by its pyramid config and the pooled parts are concatenated.

    Args:
        stpp_stage (tuple): Config of structured temporal pyramid pooling,
            one entry per stage. Default: (1, (1, 2), 1).
        num_segments_list (tuple): Number of segments to be sampled
            in three stages. Default: (2, 5, 2).
    """

    def __init__(self, stpp_stage=(1, (1, 2), 1), num_segments_list=(2, 5, 2)):
        super().__init__()

        # One (parts, multiplier) pair per stage: starting, course, ending.
        parsed_stages = tuple(
            parse_stage_config(stpp_stage[stage_idx])
            for stage_idx in range(3))

        self.stpp_stages = tuple(parts for parts, _ in parsed_stages)
        self.multiplier_list = tuple(
            multiplier for _, multiplier in parsed_stages)
        self.num_multipliers = sum(self.multiplier_list)

        self.num_segments_list = num_segments_list

    @staticmethod
    def _extract_stage_feature(stage_feat, stage_parts, num_multipliers,
                               scale_factors, num_samples):
        """Pool the features of one stage over every part of its pyramid.

        Args:
            stage_feat (torch.Tensor): Features of the stage; dimension 1 is
                the one that gets split into parts and averaged.
            stage_parts (tuple): Number of parts of every pyramid level.
            num_multipliers (int): Total number of parts in the stage; each
                pooled part is divided by it.
            scale_factors (torch.Tensor | None): Per-sample factors that are
                reshaped to ``(num_samples, 1)`` and multiplied onto every
                pooled part. ``None`` disables the scaling.
            num_samples (int): Number of samples.

        Returns:
            list[torch.Tensor]: One pooled feature tensor per part, in
            pyramid order.
        """
        pooled_parts = []
        num_steps = stage_feat.size(1)
        for num_parts in stage_parts:
            # Evenly spaced part boundaries, truncated to integer indices.
            # The small epsilon keeps the right-most boundary in the range.
            boundaries = torch.arange(0, num_steps + 1e-5,
                                      num_steps / num_parts).int()
            for part_idx in range(num_parts):
                begin = boundaries[part_idx]
                end = boundaries[part_idx + 1]
                pooled = stage_feat[:, begin:end, :].mean(dim=1)
                pooled = pooled / num_multipliers
                if scale_factors is not None:
                    pooled = pooled * scale_factors.view(num_samples, 1)
                pooled_parts.append(pooled)
        return pooled_parts

    def forward(self, x, scale_factors):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data. It is reshaped to
                ``(-1, total number of segments, x.size(1))``.
            scale_factors (torch.Tensor): Ratios of the effective sampling
                lengths to augmented lengths. It is reshaped to ``(-1, 2)``;
                column 0 scales the starting stage and column 1 the ending
                stage.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                Features for predicting activity scores (mean over the course
                stage) and completeness scores (concatenated STPP features).
        """
        course_begin = self.num_segments_list[0]
        course_end = course_begin + self.num_segments_list[1]
        total_segments = course_end + self.num_segments_list[2]

        x = x.view(-1, total_segments, x.size(1))
        num_samples = x.size(0)
        scale_factors = scale_factors.view(-1, 2)

        # (stage features, scale factors); the course stage is never scaled.
        stage_inputs = (
            (x[:, :course_begin, :], scale_factors[:, 0]),
            (x[:, course_begin:course_end, :], None),
            (x[:, course_end:, :], scale_factors[:, 1]),
        )

        pooled_feats = []
        for (stage_feat, stage_scale), stage_parts, multiplier in zip(
                stage_inputs, self.stpp_stages, self.multiplier_list):
            pooled_feats += self._extract_stage_feature(
                stage_feat, stage_parts, multiplier, stage_scale, num_samples)
        stpp_feat = torch.cat(pooled_feats, dim=1)

        course_feat = x[:, course_begin:course_end, :].mean(dim=1)
        return course_feat, stpp_feat


class STPPTest(nn.Module):
    """Testing-time structured temporal pyramid pooling (STPP) for SSN.

    At testing the pooling is applied to per-step scores: for every proposal
    the scores of each pyramid part are averaged over the steps of that part
    and the part results are summed.

    Args:
        num_classes (int): Number of classes to be classified.
        use_regression (bool): Whether to perform regression or not.
            Default: True.
        stpp_stage (tuple): Config of structured temporal pyramid pooling.
            Default: (1, (1, 2), 1).
    """

    def __init__(self,
                 num_classes,
                 use_regression=True,
                 stpp_stage=(1, (1, 2), 1)):
        super().__init__()

        self.use_regression = use_regression

        # Width of one score group of each kind.
        self.activity_score_len = num_classes + 1
        self.complete_score_len = num_classes
        self.reg_score_len = num_classes * 2

        # One (parts, multiplier) pair per stage: starting, course, ending.
        parsed_stages = tuple(
            parse_stage_config(stpp_stage[stage_idx])
            for stage_idx in range(3))
        self.stpp_stage = tuple(parts for parts, _ in parsed_stages)
        self.num_multipliers = sum(
            multiplier for _, multiplier in parsed_stages)

        # Expected input width: one activity group plus one completeness
        # (and optionally one regression) group per pyramid part.
        per_part_len = self.complete_score_len
        if self.use_regression:
            per_part_len = per_part_len + self.reg_score_len
        self.feat_dim = (
            self.activity_score_len + self.num_multipliers * per_part_len)

        # Column ranges of the three score kinds, laid out back to back.
        complete_end = (
            self.activity_score_len +
            self.complete_score_len * self.num_multipliers)
        reg_end = complete_end + self.reg_score_len * self.num_multipliers
        self.activity_slice = slice(0, self.activity_score_len)
        self.complete_slice = slice(self.activity_score_len, complete_end)
        self.reg_slice = slice(complete_end, reg_end)

    @staticmethod
    def _pyramids_pooling(out_scores, index, raw_scores, ticks, scale_factors,
                          score_len, stpp_stage):
        """Accumulate the pooled part scores of one proposal.

        Args:
            out_scores (torch.Tensor): Scores to be returned; row ``index``
                is updated in place.
            index (int): Index of output scores.
            raw_scores (torch.Tensor): Raw scores before STPP. Rows are
                sliced by tick ranges, columns hold one ``score_len`` wide
                group per pyramid part.
            ticks (Sequence): Stage boundaries of the proposal;
                ``ticks[k]`` and ``ticks[k + 1]`` delimit stage ``k``.
            scale_factors (Sequence): Ratios of the effective sampling
                lengths to augmented lengths. Entry 0 scales the first stage
                and entry 1 the last stage.
            score_len (int): Length of the score.
            stpp_stage (tuple): Config of STPP.

        Returns:
            torch.Tensor: ``out_scores`` (the same object that was passed in).
        """
        last_stage_idx = len(stpp_stage) - 1
        part_offset = 0
        for stage_idx, stage_cfg in enumerate(stpp_stage):
            # Only the first and the last stage are rescaled.
            if stage_idx == 0:
                stage_scale = scale_factors[0]
            elif stage_idx == last_stage_idx:
                stage_scale = scale_factors[1]
            else:
                stage_scale = 1.0

            tick_left = ticks[stage_idx]
            # Every stage spans at least one step.
            tick_right = float(max(ticks[stage_idx] + 1, ticks[stage_idx + 1]))

            # A stage lying completely outside the raw scores contributes
            # nothing, but its parts still occupy their column groups.
            if tick_right <= 0 or tick_left >= raw_scores.size(0):
                part_offset += sum(stage_cfg)
                continue

            for num_parts in stage_cfg:
                part_ticks = torch.arange(tick_left, tick_right + 1e-5,
                                          (tick_right - tick_left) /
                                          num_parts).int()

                for part_idx in range(num_parts):
                    part_begin = part_ticks[part_idx]
                    part_end = part_ticks[part_idx + 1]
                    if part_end - part_begin >= 1:
                        columns = slice(part_offset * score_len,
                                        (part_offset + 1) * score_len)
                        pooled = raw_scores[part_begin:part_end,
                                            columns].mean(dim=0)
                        scaled = pooled * stage_scale
                        out_scores[index, :] += scaled.detach().cpu()
                    part_offset += 1

        return out_scores

    def forward(self, x, proposal_ticks, scale_factors):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data; its second dimension must equal
                ``self.feat_dim``.
            proposal_ticks (torch.Tensor): Ticks of proposals to be STPP, one
                row per proposal.
            scale_factors (Sequence): Ratios of the effective sampling lengths
                to augmented lengths, one entry per proposal.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                out_activity_scores (torch.Tensor): Activity scores
                out_complete_scores (torch.Tensor): Completeness scores.
                out_reg_scores (torch.Tensor): Regression scores, or None
                when regression is disabled.
        """
        assert x.size(1) == self.feat_dim
        num_proposals = proposal_ticks.size(0)

        def new_scores(width):
            # Output buffers are plain ``torch.zeros`` tensors of x's dtype.
            return torch.zeros((num_proposals, width), dtype=x.dtype)

        raw_activity = x[:, self.activity_slice]
        raw_complete = x[:, self.complete_slice]
        activity_out = new_scores(self.activity_score_len)
        complete_out = new_scores(self.complete_score_len)

        if self.use_regression:
            raw_reg = x[:, self.reg_slice]
            reg_out = new_scores(self.reg_score_len)
        else:
            raw_reg = None
            reg_out = None

        for proposal_idx in range(num_proposals):
            ticks = proposal_ticks[proposal_idx]
            proposal_scale = scale_factors[proposal_idx]

            # Activity scores are averaged over the course stage only, which
            # always spans at least one step.
            course_end = max(ticks[1] + 1, ticks[2])
            activity_out[proposal_idx, :] = raw_activity[
                ticks[1]:course_end, :].mean(dim=0)

            complete_out = self._pyramids_pooling(
                complete_out, proposal_idx, raw_complete, ticks,
                proposal_scale, self.complete_score_len, self.stpp_stage)

            if self.use_regression:
                reg_out = self._pyramids_pooling(
                    reg_out, proposal_idx, raw_reg, ticks, proposal_scale,
                    self.reg_score_len, self.stpp_stage)

        return activity_out, complete_out, reg_out


@HEADS.register_module()
class SSNHead(nn.Module):
    """The classification head for SSN.

    Args:
        dropout_ratio (float): Probability of dropout layer. Default: 0.8.
        in_channels (int): Number of channels for input data. Default: 1024.
        num_classes (int): Number of classes to be classified. Default: 20.
        consensus (dict): Config of segmental consensus. ``type`` must be
            ``'STPPTrain'`` or ``'STPPTest'``; all remaining keys are passed
            to the selected class as keyword arguments. Default:
            ``dict(type='STPPTrain', stpp_stage=(1, 1, 1),
            num_segments_list=(2, 5, 2))``.
        use_regression (bool): Whether to perform regression or not.
            Default: True.
        init_std (float): Std value for Initiation. Default: 0.001.

    Raises:
        ValueError: If ``consensus['type']`` is not a supported type.
    """

    def __init__(self,
                 dropout_ratio=0.8,
                 in_channels=1024,
                 num_classes=20,
                 consensus=dict(
                     type='STPPTrain',
                     stpp_stage=(1, 1, 1),
                     num_segments_list=(2, 5, 2)),
                 use_regression=True,
                 init_std=0.001):

        super().__init__()

        self.dropout_ratio = dropout_ratio
        self.num_classes = num_classes
        self.use_regression = use_regression
        self.init_std = init_std

        if self.dropout_ratio != 0:
            self.dropout = nn.Dropout(p=self.dropout_ratio)
        else:
            self.dropout = None

        # Based on this copy, the model will utilize different
        # structured temporal pyramid pooling at training and testing.
        # Warning: this copy cannot be removed.
        consensus_ = consensus.copy()
        consensus_type = consensus_.pop('type')
        if consensus_type == 'STPPTrain':
            self.consensus = STPPTrain(**consensus_)
        elif consensus_type == 'STPPTest':
            consensus_['num_classes'] = self.num_classes
            # STPPTest derives its expected input width from use_regression,
            # and that width has to match the fc built by prepare_test_fc.
            # Follow the head's setting unless the config overrides it.
            consensus_.setdefault('use_regression', self.use_regression)
            self.consensus = STPPTest(**consensus_)
        else:
            # Fail here with a clear message instead of hitting a missing
            # ``self.consensus`` attribute a few lines below.
            raise ValueError(
                f'Unsupported consensus type {consensus_type!r}, '
                "expected 'STPPTrain' or 'STPPTest'")

        self.in_channels_activity = in_channels
        # Completeness/regression features are the concatenation of one
        # in_channels wide feature per pyramid part.
        self.in_channels_complete = (
            self.consensus.num_multipliers * in_channels)
        self.activity_fc = nn.Linear(in_channels, num_classes + 1)
        self.completeness_fc = nn.Linear(self.in_channels_complete,
                                         num_classes)
        if self.use_regression:
            self.regressor_fc = nn.Linear(self.in_channels_complete,
                                          num_classes * 2)

    def init_weights(self):
        """Initiate the parameters from scratch."""
        normal_init(self.activity_fc, std=self.init_std)
        normal_init(self.completeness_fc, std=self.init_std)
        if self.use_regression:
            normal_init(self.regressor_fc, std=self.init_std)

    def prepare_test_fc(self, stpp_feat_multiplier):
        """Reorganize the shape of fully connected layer at testing, in order
        to improve testing efficiency.

        The activity, completeness and (optional) regression layers are
        merged into a single ``self.test_fc`` that takes ``in_features`` wide
        inputs and emits one score group per pyramid part.

        Args:
            stpp_feat_multiplier (int): Total number of parts.

        Returns:
            bool: Whether the shape transformation is ready for testing.
        """
        in_features = self.activity_fc.in_features
        out_features = (
            self.activity_fc.out_features +
            self.completeness_fc.out_features * stpp_feat_multiplier)
        if self.use_regression:
            out_features += (
                self.regressor_fc.out_features * stpp_feat_multiplier)
        self.test_fc = nn.Linear(in_features, out_features)

        # Fetch weight and bias of the reorganized fc.
        # The (out, parts * in) weight is regrouped to (parts * out, in) so
        # that every part gets its own block of output rows.
        complete_weight = self.completeness_fc.weight.data.view(
            self.completeness_fc.out_features, stpp_feat_multiplier,
            in_features).transpose(0, 1).contiguous().view(-1, in_features)
        # The bias is repeated for every part and split evenly between them.
        complete_bias = self.completeness_fc.bias.data.view(1, -1).expand(
            stpp_feat_multiplier, self.completeness_fc.out_features
        ).contiguous().view(-1) / stpp_feat_multiplier

        weight = torch.cat((self.activity_fc.weight.data, complete_weight))
        bias = torch.cat((self.activity_fc.bias.data, complete_bias))

        if self.use_regression:
            reg_weight = self.regressor_fc.weight.data.view(
                self.regressor_fc.out_features, stpp_feat_multiplier,
                in_features).transpose(0, 1).contiguous().view(-1, in_features)
            reg_bias = self.regressor_fc.bias.data.view(1, -1).expand(
                stpp_feat_multiplier, self.regressor_fc.out_features
            ).contiguous().view(-1) / stpp_feat_multiplier
            weight = torch.cat((weight, reg_weight))
            bias = torch.cat((bias, reg_bias))

        self.test_fc.weight.data = weight
        self.test_fc.bias.data = bias
        return True

    def forward(self, x, test_mode=False):
        """Defines the computation performed at every call.

        Args:
            x (tuple): At training ``(features, proposal_scale_factor)``; at
                testing ``(features, proposal_tick_list, scale_factor_list)``.
            test_mode (bool): Whether to run the testing branch, which uses
                ``self.test_fc`` created by :meth:`prepare_test_fc`.
                Default: False.

        Returns:
            tuple: At training ``(activity_scores, complete_scores,
            bbox_preds)`` where ``bbox_preds`` is None without regression.
            At testing ``(test_scores, activity_scores, completeness_scores,
            bbox_preds)``.
        """
        if not test_mode:
            x, proposal_scale_factor = x
            activity_feat, completeness_feat = self.consensus(
                x, proposal_scale_factor)

            if self.dropout is not None:
                activity_feat = self.dropout(activity_feat)
                completeness_feat = self.dropout(completeness_feat)

            activity_scores = self.activity_fc(activity_feat)
            complete_scores = self.completeness_fc(completeness_feat)
            if self.use_regression:
                bbox_preds = self.regressor_fc(completeness_feat)
                # Two regression values per class.
                bbox_preds = bbox_preds.view(
                    -1, self.completeness_fc.out_features, 2)
            else:
                bbox_preds = None
            return activity_scores, complete_scores, bbox_preds

        x, proposal_tick_list, scale_factor_list = x
        test_scores = self.test_fc(x)
        (activity_scores, completeness_scores,
         bbox_preds) = self.consensus(test_scores, proposal_tick_list,
                                      scale_factor_list)

        return (test_scores, activity_scores, completeness_scores, bbox_preds)