from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from ..registry import PIPELINES
def to_tensor(data):
    """Turn a supported python object into a :obj:`torch.Tensor`.

    Accepted inputs are :class:`torch.Tensor`, :class:`numpy.ndarray`,
    non-string :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data: Object to convert.

    Returns:
        torch.Tensor: ``data`` itself when it already is a tensor,
            otherwise a tensor built from it. A single ``int`` becomes a
            one-element ``LongTensor`` and a single ``float`` a one-element
            ``FloatTensor``.

    Raises:
        TypeError: If ``data`` is none of the supported types.
    """
    # Tensors are handed back untouched (same object, no copy).
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    # Strings are sequences too, so they are excluded explicitly.
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    if isinstance(data, int):
        return torch.LongTensor([data])
    if isinstance(data, float):
        return torch.FloatTensor([data])
    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module()
class ToTensor:
    """Pipeline step that converts selected entries of the results dict to
    `torch.Tensor`.

    Args:
        keys (Sequence[str]): Names of the entries to convert.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Apply the ToTensor formatting.

        Args:
            results (dict): Dict travelling through the pipeline. Each entry
                named in ``keys`` is replaced in place by its tensor form.

        Returns:
            dict: The same ``results`` object.
        """
        for name in self.keys:
            value = results[name]
            results[name] = to_tensor(value)
        return results

    def __repr__(self):
        cls_name = type(self).__name__
        return f'{cls_name}(keys={self.keys})'
@PIPELINES.register_module()
class ToDataContainer:
    """Wrap selected entries of the results dict in a DataContainer.

    Args:
        fields (Sequence[dict]): One dict per entry to wrap. Its ``'key'``
            item names the entry; every other item is forwarded to the
            DataContainer constructor as a keyword argument. E.g.
            fields=(dict(key='gt_bbox', stack=False),).
    """

    def __init__(self, fields):
        self.fields = fields

    def __call__(self, results):
        """Apply the ToDataContainer formatting.

        Args:
            results (dict): Dict travelling through the pipeline. Each
                configured entry is replaced in place by a DataContainer.

        Returns:
            dict: The same ``results`` object.
        """
        for spec in self.fields:
            # Work on a copy so the configured field dict is left intact.
            dc_kwargs = dict(spec)
            name = dc_kwargs.pop('key')
            results[name] = DC(results[name], **dc_kwargs)
        return results

    def __repr__(self):
        return f'{type(self).__name__}(fields={self.fields})'
@PIPELINES.register_module()
class ImageToTensor:
    """Convert image type to `torch.Tensor` type.

    The last axis of each image is moved to the front, e.g.
    (h, w, c) -> (c, h, w), before the array is converted. A 2-D (h, w)
    image is first given a trailing channel axis, so it ends up as a
    (1, h, w) tensor.

    Args:
        keys (Sequence[str]): Required keys to be converted.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Performs the ImageToTensor formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: The same ``results`` object with the entries named in
                ``keys`` replaced by tensors.
        """
        for key in self.keys:
            img = results[key]
            if img.ndim == 2:
                # Single-channel images may come without a channel axis;
                # add one so the 3-axis transpose below is valid.
                img = np.expand_dims(img, -1)
            results[key] = to_tensor(img.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class Transpose:
    """Reorder the axes of selected entries of the results dict.

    Args:
        keys (Sequence[str]): Names of the entries to transpose.
        order (Sequence[int]): Axis order handed to each entry's
            ``transpose`` method.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        """Apply the Transpose formatting.

        Args:
            results (dict): Dict travelling through the pipeline. Each entry
                named in ``keys`` is replaced in place by its transposed
                version.

        Returns:
            dict: The same ``results`` object.
        """
        for name in self.keys:
            original = results[name]
            results[name] = original.transpose(self.order)
        return results

    def __repr__(self):
        cls_name = type(self).__name__
        return f'{cls_name}(keys={self.keys}, order={self.order})'
@PIPELINES.register_module()
class Collect:
    """Collect data from the loader relevant to the specific task.

    This keeps the items in ``keys`` as they are, and collects the items in
    ``meta_keys`` into a single meta item called ``meta_name``. This is
    usually the last stage of the data loader pipeline.

    For example, when keys=['imgs'], meta_keys=('filename', 'label',
    'original_shape'), meta_name='img_meta', the results will be a dict with
    keys 'imgs' and 'img_meta', where 'img_meta' is a DataContainer of another
    dict with keys 'filename', 'label', 'original_shape'.

    Args:
        keys (Sequence[str]): Required keys to be collected.
        meta_keys (Sequence[str]): Keys that are collected under meta_name.
            The contents of the ``meta_name`` dictionary depend on
            ``meta_keys``. Every listed key must be present in the results
            dict. By default this includes:

            - "filename": path to the image file
            - "label": label of the image file
            - "original_shape": original shape of the image as a tuple
              (h, w, c)
            - "img_shape": shape of the image input to the network as a
              tuple (h, w, c). Note that images may be zero padded on the
              bottom/right, if the batch tensor is larger than this shape.
            - "pad_shape": image shape after padding
            - "flip_direction": a str in ("horizontal", "vertical") to
              indicate if the image is flipped horizontally or vertically.
            - "img_norm_cfg": a dict of normalization information:

              - mean - per channel mean subtraction
              - std - per channel std divisor
              - to_rgb - bool indicating if bgr was converted to rgb
        meta_name (str): The name of the key that contains meta information.
            This key is only added to the output when ``meta_keys`` is not
            empty. Default: "img_meta".
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'label', 'original_shape', 'img_shape',
                            'pad_shape', 'flip_direction', 'img_norm_cfg'),
                 meta_name='img_meta'):
        self.keys = keys
        self.meta_keys = meta_keys
        self.meta_name = meta_name

    def __call__(self, results):
        """Performs the Collect formatting.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.

        Returns:
            dict: A new dict holding the entries named in ``keys`` and, when
                ``meta_keys`` is not empty, a CPU-only DataContainer of the
                meta entries stored under ``meta_name``.

        Raises:
            KeyError: If a name in ``keys`` or ``meta_keys`` is missing from
                ``results``.
        """
        data = {key: results[key] for key in self.keys}

        if self.meta_keys:
            meta = {key: results[key] for key in self.meta_keys}
            data[self.meta_name] = DC(meta, cpu_only=True)

        return data

    def __repr__(self):
        # Show every constructor argument so that differently configured
        # instances can be told apart in printed pipelines.
        return (f'{self.__class__.__name__}('
                f'keys={self.keys}, meta_keys={self.meta_keys}, '
                f'meta_name={self.meta_name})')