注意
您正在阅读 MMEditing 0.x。MMEditing 0.x 会在 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。
mmedit.datasets.base_generation_dataset 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from pathlib import Path
from mmcv import scandir
from .base_dataset import BaseDataset
# Image file suffixes accepted when scanning folders for images. Each suffix
# is listed in lowercase followed by its all-uppercase variant, in case the
# suffix matching performed by the scanner is case-sensitive.
IMG_EXTENSIONS = tuple(
    variant
    for ext in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff')
    for variant in (ext, ext.upper()))
class BaseGenerationDataset(BaseDataset):
    """Base class for generation datasets.

    Provides a helper that collects image files from a folder tree and an
    ``evaluate`` that only counts how many generated images were saved.
    """

    @staticmethod
    def scan_folder(path):
        """Obtain image path list (including sub-folders) from a given folder.

        Only files matching one of the suffixes in ``IMG_EXTENSIONS`` are
        collected; the folder is scanned recursively.

        Args:
            path (str | :obj:`Path`): Folder path.

        Returns:
            list[str]: Image paths found under the folder, each joined with
            ``path``.

        Raises:
            TypeError: If ``path`` is neither a ``str`` nor a ``Path``.
            AssertionError: If no valid image file is found.
        """
        if not isinstance(path, (str, Path)):
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')
        path = str(path)
        images = scandir(path, suffix=IMG_EXTENSIONS, recursive=True)
        images = [osp.join(path, v) for v in images]
        # Raise explicitly rather than using ``assert`` so the check survives
        # ``python -O``; AssertionError is kept so existing callers that
        # catch it keep working.
        if not images:
            raise AssertionError(f'{path} has no valid image file.')
        return images

    def evaluate(self, results, logger=None):
        """Evaluate by counting saved generated images (needs no metrics).

        Args:
            results (list[dict]): The outputs of ``forward_test()`` of the
                model, one per dataset sample. Each item must provide the key
                ``'saved_flag'``; a truthy value counts as a saved image.
            logger: Not used by this implementation.

        Returns:
            dict: ``{'val_saved_number': n}``, where ``n`` is the number of
            results whose ``'saved_flag'`` is truthy.

        Raises:
            TypeError: If ``results`` is not a list.
            AssertionError: If ``len(results)`` differs from ``len(self)``.
            KeyError: If an item lacks the ``'saved_flag'`` key.
        """
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        # Explicit raise (not ``assert``) so the length check is not stripped
        # under ``python -O``; the exception type is unchanged for callers.
        if len(results) != len(self):
            raise AssertionError(
                'The length of results is not equal to the dataset len: '
                f'{len(results)} != {len(self)}')
        # Every item's 'saved_flag' is read; only truthy flags are counted.
        saved_num = sum(1 for res in results if res['saved_flag'])
        return {'val_saved_number': saved_num}