import copy
import os.path as osp
import torch
from mmcv.utils import print_log
from ..core import mean_average_precision, mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS
@DATASETS.register_module()
class RawframeDataset(BaseDataset):
    """Rawframe dataset for action recognition.

    The dataset loads raw frames and applies the specified transforms to
    return a dict containing the frame tensors and other information.

    The ann_file is a text file with multiple lines, and each line indicates
    the directory to frames of a video, total frames of the video and
    the label of a video, which are split with a whitespace.
    Example of an annotation file:

    .. code-block:: txt

        some/directory-1 163 1
        some/directory-2 122 1
        some/directory-3 258 2
        some/directory-4 234 2
        some/directory-5 295 3
        some/directory-6 121 3

    Example of a multi-class annotation file:

    .. code-block:: txt

        some/directory-1 163 1 3 5
        some/directory-2 122 1 2
        some/directory-3 258 2
        some/directory-4 234 2 4 6 8
        some/directory-5 295 3
        some/directory-6 121 3

    Example of a with_offset annotation file (clips from long videos), each
    line indicates the directory to frames of a video, the index of the start
    frame, total frames of the video clip and the label of a video clip, which
    are split with a whitespace.

    .. code-block:: txt

        some/directory-1 12 163 3
        some/directory-2 213 122 4
        some/directory-3 100 258 5
        some/directory-4 98 234 2
        some/directory-5 0 295 3
        some/directory-6 50 121 3

    Blank lines in the annotation file are ignored. An ann_file whose name
    ends with ``.json`` is loaded through ``load_json_annotations`` instead
    of the text parser.

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transforms.
        data_prefix (str): Path to a directory where videos are held.
            Default: None.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
        filename_tmpl (str): Template for each filename.
            Default: 'img_{:05}.jpg'.
        with_offset (bool): Determines whether the offset information is in
            ann_file. Default: False.
        multi_class (bool): Determines whether it is a multi-class
            recognition dataset. Default: False.
        num_classes (int): Number of classes in the dataset. Default: None.
        start_index (int): Forwarded to the base class and handed to the
            pipeline under the ``start_index`` key of every sample
            (presumably the index of the first frame file; confirm against
            the frame-loading transforms). Default: 1.
        modality (str): Modality of data. Support 'RGB', 'Flow'.
            Default: 'RGB'.
    """

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_prefix=None,
                 test_mode=False,
                 filename_tmpl='img_{:05}.jpg',
                 with_offset=False,
                 multi_class=False,
                 num_classes=None,
                 start_index=1,
                 modality='RGB'):
        # These two attributes are assigned before the base initializer runs;
        # presumably the base class calls ``load_annotations`` (which reads
        # ``with_offset``) from its ``__init__`` -- keep this ordering.
        self.filename_tmpl = filename_tmpl
        self.with_offset = with_offset
        super().__init__(ann_file, pipeline, data_prefix, test_mode,
                         multi_class, num_classes, start_index, modality)

    def load_annotations(self):
        """Load annotation file to get video information.

        Returns:
            list[dict]: One dict per annotated video with the keys
                ``frame_dir``, ``total_frames``, ``label`` and, when
                ``with_offset`` is set, ``offset``. For ``.json`` annotation
                files the result of ``load_json_annotations`` is returned
                unchanged.
        """
        if self.ann_file.endswith('.json'):
            return self.load_json_annotations()

        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line_split = line.strip().split()
                if not line_split:
                    # Tolerate blank lines (e.g. a trailing newline at the
                    # end of the file) instead of failing with IndexError.
                    continue

                video_info = {}
                idx = 0
                # idx for frame_dir
                frame_dir = line_split[idx]
                if self.data_prefix is not None:
                    frame_dir = osp.join(self.data_prefix, frame_dir)
                video_info['frame_dir'] = frame_dir
                idx += 1

                if self.with_offset:
                    # idx for offset and total_frames
                    video_info['offset'] = int(line_split[idx])
                    video_info['total_frames'] = int(line_split[idx + 1])
                    idx += 2
                else:
                    # idx for total_frames
                    video_info['total_frames'] = int(line_split[idx])
                    idx += 1

                # idx for label[s]: every remaining field is a class index
                label = [int(x) for x in line_split[idx:]]
                assert len(label), f'missing label in line: {line}'
                if self.multi_class:
                    assert self.num_classes is not None
                    # Multi-hot target: 1.0 at each annotated class index.
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                    video_info['label'] = onehot
                else:
                    assert len(label) == 1
                    video_info['label'] = label[0]

                video_infos.append(video_info)

        return video_infos

    def _prepare_frames(self, idx):
        """Build the sample dict for ``idx`` and run it through the pipeline.

        The stored annotation is deep-copied so that the pipeline cannot
        modify the entry kept in ``self.video_infos``.
        """
        results = copy.deepcopy(self.video_infos[idx])
        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        results['start_index'] = self.start_index
        return self.pipeline(results)

    def prepare_train_frames(self, idx):
        """Prepare the frames for training given the index."""
        return self._prepare_frames(idx)

    def prepare_test_frames(self, idx):
        """Prepare the frames for testing given the index."""
        return self._prepare_frames(idx)

    def evaluate(self,
                 results,
                 metrics='top_k_accuracy',
                 topk=(1, 5),
                 logger=None):
        """Evaluation in rawframe dataset.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Defaults: 'top_k_accuracy'.
            topk (int | tuple[int]): K value for top_k_accuracy metric.
                Defaults: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Returns:
            dict: Evaluation results dict.

        Raises:
            TypeError: If ``results`` is not a list, or ``topk`` is neither
                an int nor a tuple.
            KeyError: If a requested metric is not supported.
        """
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')

        if isinstance(topk, int):
            topk = (topk, )

        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
        allowed_metrics = [
            'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision'
        ]
        # Validate every metric up front so nothing is computed or logged
        # when the request contains an unsupported name.
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        eval_results = {}
        gt_labels = [ann['label'] for ann in self.video_infos]

        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
                continue

            if metric == 'mean_average_precision':
                # Convert into a separate list: rebinding ``gt_labels`` here
                # would leak the numpy labels into any metric evaluated
                # afterwards and break a repeated 'mean_average_precision'.
                gt_labels_np = [label.cpu().numpy() for label in gt_labels]
                mAP = mean_average_precision(results, gt_labels_np)
                eval_results['mean_average_precision'] = mAP
                log_msg = f'\nmean_average_precision\t{mAP:.4f}'
                print_log(log_msg, logger=logger)
                continue

        return eval_results