# Source code for mmaction.datasets.rawframe_dataset

import copy
import os.path as osp

import torch
from mmcv.utils import print_log

from ..core import mean_average_precision, mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS


@DATASETS.register_module()
class RawframeDataset(BaseDataset):
    """Rawframe dataset for action recognition.

    The dataset loads raw frames and applies specified transforms to return
    a dict containing the frame tensors and other information.

    The ann_file is a text file with multiple lines, and each line indicates
    the directory to frames of a video, total frames of the video and
    the label of a video, which are split with a whitespace.
    Example of an annotation file:

    .. code-block:: txt

        some/directory-1 163 1
        some/directory-2 122 1
        some/directory-3 258 2
        some/directory-4 234 2
        some/directory-5 295 3
        some/directory-6 121 3

    Example of a multi-class annotation file:

    .. code-block:: txt

        some/directory-1 163 1 3 5
        some/directory-2 122 1 2
        some/directory-3 258 2
        some/directory-4 234 2 4 6 8
        some/directory-5 295 3
        some/directory-6 121 3

    Example of a with_offset annotation file (clips from long videos), each
    line indicates the directory to frames of a video, the index of the start
    frame, total frames of the video clip and the label of a video clip, which
    are split with a whitespace.

    .. code-block:: txt

        some/directory-1 12 163 3
        some/directory-2 213 122 4
        some/directory-3 100 258 5
        some/directory-4 98 234 2
        some/directory-5 0 295 3
        some/directory-6 50 121 3

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transforms.
        data_prefix (str): Path to a directory where videos are held.
            Default: None.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
        filename_tmpl (str): Template for each filename.
            Default: 'img_{:05}.jpg'.
        with_offset (bool): Determines whether the offset information is in
            ann_file. Default: False.
        multi_class (bool): Determines whether it is a multi-class
            recognition dataset. Default: False.
        num_classes (int): Number of classes in the dataset. Default: None.
        start_index (int): Start index handed to the base class and passed
            to the pipeline under the ``start_index`` key (presumably the
            number of the first frame file -- confirm against the pipeline).
            Default: 1.
        modality (str): Modality of data. Support 'RGB', 'Flow'.
            Default: 'RGB'.
    """

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_prefix=None,
                 test_mode=False,
                 filename_tmpl='img_{:05}.jpg',
                 with_offset=False,
                 multi_class=False,
                 num_classes=None,
                 start_index=1,
                 modality='RGB'):
        # These two attributes are read by ``load_annotations`` and by the
        # frame-preparation methods, so they are set before delegating to
        # the base-class constructor.
        self.filename_tmpl = filename_tmpl
        self.with_offset = with_offset
        super().__init__(ann_file, pipeline, data_prefix, test_mode,
                         multi_class, num_classes, start_index, modality)

    def load_annotations(self):
        """Load annotation file to get video information.

        A ``.json`` annotation file is delegated to
        ``self.load_json_annotations()``. Any other file is parsed as
        whitespace-separated text, one video per line; blank lines are
        skipped.

        Returns:
            list[dict]: One dict per video with the keys ``frame_dir``,
            ``total_frames``, ``label`` and, when ``with_offset`` is set,
            ``offset``. ``label`` is an int for single-label data and a
            one-hot ``torch.Tensor`` of length ``num_classes`` for
            multi-class data.

        Raises:
            AssertionError: If a line carries no label, if ``multi_class``
                is set without ``num_classes``, or if a single-label line
                carries more than one label.
        """
        if self.ann_file.endswith('.json'):
            return self.load_json_annotations()

        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line_split = line.strip().split()
                if not line_split:
                    # A blank line holds no record; skip it instead of
                    # failing with an IndexError on the first field.
                    continue

                video_info = {}
                idx = 0
                # idx for frame_dir
                frame_dir = line_split[idx]
                if self.data_prefix is not None:
                    frame_dir = osp.join(self.data_prefix, frame_dir)
                video_info['frame_dir'] = frame_dir
                idx += 1

                if self.with_offset:
                    # idx for offset and total_frames
                    video_info['offset'] = int(line_split[idx])
                    video_info['total_frames'] = int(line_split[idx + 1])
                    idx += 2
                else:
                    # idx for total_frames
                    video_info['total_frames'] = int(line_split[idx])
                    idx += 1

                # idx for label[s]: everything left on the line
                label = [int(x) for x in line_split[idx:]]
                assert len(label), f'missing label in line: {line}'
                if self.multi_class:
                    assert self.num_classes is not None
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                    video_info['label'] = onehot
                else:
                    assert len(label) == 1
                    video_info['label'] = label[0]
                video_infos.append(video_info)

        return video_infos

    def _prepare_frames(self, idx):
        """Build the pipeline input for sample ``idx`` and run the pipeline.

        The stored annotation is deep-copied so that the pipeline cannot
        modify ``self.video_infos``.
        """
        results = copy.deepcopy(self.video_infos[idx])
        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        results['start_index'] = self.start_index
        return self.pipeline(results)

    def prepare_train_frames(self, idx):
        """Prepare the frames for training given the index."""
        return self._prepare_frames(idx)

    def prepare_test_frames(self, idx):
        """Prepare the frames for testing given the index."""
        return self._prepare_frames(idx)

    def evaluate(self,
                 results,
                 metrics='top_k_accuracy',
                 topk=(1, 5),
                 logger=None):
        """Evaluation in rawframe dataset.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Defaults: 'top_k_accuracy'.
            topk (int | tuple[int]): K value for top_k_accuracy metric.
                Defaults: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Returns:
            dict: Evaluation results dict.

        Raises:
            TypeError: If ``results`` is not a list, or ``topk`` is neither
                an int nor a tuple.
            AssertionError: If ``results`` and the dataset differ in length.
            KeyError: If a requested metric is not supported.
        """
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')

        if isinstance(topk, int):
            topk = (topk, )

        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
        allowed_metrics = [
            'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision'
        ]
        # Validate every metric up front so nothing is computed when the
        # request contains an unsupported name.
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        eval_results = {}
        gt_labels = [ann['label'] for ann in self.video_infos]

        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
            elif metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
            elif metric == 'mean_average_precision':
                # Convert into a separate list. Rebinding ``gt_labels`` here
                # would hand numpy arrays to every later metric and would
                # fail on ``.cpu()`` if this metric were requested twice.
                gt_labels_np = [label.cpu().numpy() for label in gt_labels]
                mAP = mean_average_precision(results, gt_labels_np)
                eval_results['mean_average_precision'] = mAP
                log_msg = f'\nmean_average_precision\t{mAP:.4f}'
                print_log(log_msg, logger=logger)

        return eval_results