Shortcuts

注意

您正在阅读 MMEditing 0.x。 MMEditing 0.x 会在 2022 年末开始逐步停止维护,建议您及时升级到 MMEditing 1.0 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志代码文档 以了解更多。

mmedit.datasets.base_dataset 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from abc import ABCMeta, abstractmethod

from torch.utils.data import Dataset

from .pipelines import Compose


[文档]class BaseDataset(Dataset, metaclass=ABCMeta): """Base class for datasets. All datasets should subclass it. All subclasses should overwrite: ``load_annotations``, supporting to load information and generate image lists. Args: pipeline (list[dict | callable]): A sequence of data transforms. test_mode (bool): If True, the dataset will work in test mode. Otherwise, in train mode. """ def __init__(self, pipeline, test_mode=False): super().__init__() self.test_mode = test_mode self.pipeline = Compose(pipeline)
[文档] @abstractmethod def load_annotations(self): """Abstract function for loading annotation. All subclasses should overwrite this function """
[文档] def prepare_train_data(self, idx): """Prepare training data. Args: idx (int): Index of the training batch data. Returns: dict: Returned training batch. """ results = copy.deepcopy(self.data_infos[idx]) return self.pipeline(results)
[文档] def prepare_test_data(self, idx): """Prepare testing data. Args: idx (int): Index for getting each testing batch. Returns: Tensor: Returned testing batch. """ results = copy.deepcopy(self.data_infos[idx]) return self.pipeline(results)
def __len__(self): """Length of the dataset. Returns: int: Length of the dataset. """ return len(self.data_infos) def __getitem__(self, idx): """Get item at each call. Args: idx (int): Index for getting each item. """ if self.test_mode: return self.prepare_test_data(idx) return self.prepare_train_data(idx)