Shortcuts

注意

您正在阅读 MMEditing 0.x。MMEditing 0.x 会在 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。

mmedit.datasets.img_inpainting_dataset 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path

from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module()
class ImgInpaintingDataset(BaseDataset):
    """Image dataset for inpainting.

    Each non-blank line of the annotation file names one ground-truth image.
    Only the first whitespace-separated token of a line is used as the image
    path, which is joined onto ``data_prefix``.

    Args:
        ann_file (str | Path): Path to the annotation file.
        pipeline: Processing pipeline, forwarded to ``BaseDataset``.
        data_prefix (str | Path | None): Directory prepended to every image
            path. ``None`` means the paths in ``ann_file`` are used as-is.
            Default: None.
        test_mode (bool): Forwarded to ``BaseDataset``. Default: False.
    """

    def __init__(self, ann_file, pipeline, data_prefix=None, test_mode=False):
        super().__init__(pipeline, test_mode)
        self.ann_file = str(ann_file)
        # str(None) would be the literal 'None' and silently turn every image
        # path into 'None/<name>'; an empty prefix leaves the paths unchanged.
        self.data_prefix = '' if data_prefix is None else str(data_prefix)
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for dataset.

        Returns:
            list[dict]: One dict per non-blank annotation line with keys
                ``gt_img_path`` (POSIX-style path string) and ``gt_img_idx``
                (0-based position of the image in the returned list).
        """
        prefix = Path(self.data_prefix)
        img_infos = []
        with open(self.ann_file, 'r') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines (e.g. trailing empty lines) instead of
                # emitting an entry that points at the prefix directory.
                if not line:
                    continue
                # Split on any whitespace so tab-separated files work too.
                img_name = line.split()[0]
                img_infos.append(
                    dict(
                        gt_img_path=prefix.joinpath(img_name).as_posix(),
                        gt_img_idx=len(img_infos)))
        return img_infos

    def evaluate(self, outputs, logger=None, **kwargs):
        """Average per-sample metrics.

        Args:
            outputs (list[dict]): Per-sample results. Each must contain an
                ``'eval_result'`` dict mapping metric name to a number; the
                metric names are taken from the first output.
            logger: Unused; kept for interface compatibility.

        Returns:
            dict: Metric name -> mean value over ``outputs``. Empty when
                ``outputs`` is empty.
        """
        if not outputs:
            return {}
        # Average over the outputs actually provided (equal to len(self)
        # when the whole dataset was evaluated).
        num_samples = len(outputs)
        metric_keys = outputs[0]['eval_result'].keys()
        return {
            key: sum(x['eval_result'][key] for x in outputs) / num_samples
            for key in metric_keys
        }