Shortcuts

Source code for mmedit.apis.inferencers.image_super_resolution_inferencer

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from typing import Dict, List

import mmcv
import numpy as np
import torch
from mmengine import mkdir_or_exist
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate

from mmedit.utils import tensor2img
from .base_mmedit_inferencer import BaseMMEditInferencer, InputsType, PredType


class ImageSuperResolutionInferencer(BaseMMEditInferencer):
    """Inferencer that predicts with image restoration (super-resolution)
    models.

    Two modes are supported:

    - single-image SR (SISR): only ``img`` is given;
    - reference-based SR (Ref-SR): a reference image ``ref`` is also given.
    """

    # Names of the keyword arguments each inference stage accepts.
    # Presumably consumed by BaseMMEditInferencer to route user kwargs to
    # the right stage -- confirm in the base class.
    func_kwargs = dict(
        preprocess=['img', 'ref'],
        forward=[],
        visualize=['result_out_dir'],
        postprocess=[])

    @staticmethod
    def _select_pipeline_cfg(cfg) -> List[dict]:
        """Return a private copy of the data pipeline config to use.

        The first non-empty entry among ``inference_pipeline``,
        ``demo_pipeline`` and ``test_pipeline`` is used. If none is set,
        ``val_pipeline`` is used. The result is deep-copied so that later
        in-place edits never touch the model's own config.
        """
        for name in ('inference_pipeline', 'demo_pipeline', 'test_pipeline'):
            pipeline_cfg = cfg.get(name, None)
            if pipeline_cfg:
                return copy.deepcopy(pipeline_cfg)
        return copy.deepcopy(cfg.val_pipeline)

    @staticmethod
    def _remove_gt_keys(pipeline_cfg: List[dict]) -> List[dict]:
        """Strip ``gt``/``gt_path`` handling from a pipeline config in place.

        For each of those names, this helper:

        - drops steps whose ``key`` equals the name;
        - removes the name from every ``keys``/``meta_keys`` list;
        - drops steps whose ``keys`` list becomes empty as a result.

        Presumably this is because no ground-truth image exists at inference
        time -- confirm against the dataset configs.

        Returns:
            List[dict]: The same (mutated) ``pipeline_cfg`` list.
        """
        for key in ('gt', 'gt_path'):
            # Iterate over a snapshot because steps may be removed.
            for step in list(pipeline_cfg):
                if 'key' in step and step['key'] == key:
                    pipeline_cfg.remove(step)
                    # The step is gone; inspecting it further could call
                    # ``remove`` on it again and raise ValueError.
                    continue
                if 'keys' in step and key in step['keys']:
                    step['keys'].remove(key)
                    if len(step['keys']) == 0:
                        pipeline_cfg.remove(step)
                        continue
                if 'meta_keys' in step and key in step['meta_keys']:
                    step['meta_keys'].remove(key)
        return pipeline_cfg

    def preprocess(self, img: InputsType, ref: InputsType = None) -> Dict:
        """Process the inputs into a model-feedable format.

        Args:
            img (InputsType): Image to be restored by the model. It is passed
                to the pipeline as ``img_path``.
            ref (InputsType): Reference image for reference-based
                restoration models, passed as ``gt_path``. Defaults to None,
                which selects single-image SR.

        Returns:
            Dict: Collated batch. ``inputs`` is normalized with the config's
            ``data_preprocessor`` mean/std and placed on the model's device.
            For Ref-SR, ``data_samples`` is a one-element list, also on the
            model's device.
        """
        cfg = self.model.cfg
        device = next(self.model.parameters()).device  # model device

        pipeline_cfg = self._remove_gt_keys(self._select_pipeline_cfg(cfg))
        test_pipeline = Compose(pipeline_cfg)

        is_ref_sr = ref is not None
        if is_ref_sr:  # Ref-SR
            data = dict(img_path=img, gt_path=ref)
        else:  # SISR
            data = dict(img_path=img)
        _data = test_pipeline(data)

        # Normalize by hand using the data_preprocessor statistics from the
        # config. view(-1, 1, 1) broadcasts per-channel stats over a
        # channel-first image (assumed (C, H, W) -- confirm against the
        # pack transform) without hard-coding 3 channels.
        preprocessor_cfg = cfg['model']['data_preprocessor']
        mean = torch.Tensor(preprocessor_cfg['mean']).view(-1, 1, 1)
        std = torch.Tensor(preprocessor_cfg['std']).view(-1, 1, 1)
        batch = collate([dict(inputs=(_data['inputs'] - mean) / std)])

        # Use .to(device) rather than .cuda() so the data lands on exactly
        # the device the model lives on (e.g. 'cuda:1'), not the current
        # default CUDA device. On CPU this is a no-op.
        batch['inputs'] = batch['inputs'].to(device)
        if is_ref_sr:
            batch['data_samples'] = [_data['data_samples'].to(device)]

        return batch

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model without tracking gradients.

        Args:
            inputs (InputsType): The dict produced by :meth:`preprocess`. It
                is unpacked as keyword arguments into the model call.

        Returns:
            PredType: The model output for ``mode='tensor'``.
        """
        with torch.no_grad():
            result = self.model(mode='tensor', **inputs)
        return result

    def visualize(self,
                  preds: PredType,
                  result_out_dir: str = None) -> List[np.ndarray]:
        """Convert the first prediction to an image and optionally save it.

        Args:
            preds (PredType): Forward results by the inferencer. Only
                ``preds[0]`` is converted.
            result_out_dir (str): Output *file path* for the restored image.
                Its parent directory is created if needed. Defaults to None,
                in which case nothing is written.

        Returns:
            List[np.ndarray]: Result of ``tensor2img(preds[0])``. This is
            presumably a single image array despite the ``List`` annotation
            -- confirm against ``tensor2img``.
        """
        results = tensor2img(preds[0])
        if result_out_dir:
            mkdir_or_exist(os.path.dirname(result_out_dir))
            mmcv.imwrite(results, result_out_dir)
        return results