Shortcuts

注意

您正在阅读 MMEditing 0.x 的文档。MMEditing 0.x 将于 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多信息。

mmedit.models.restorers.basic_restorer 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import numbers
import os.path as osp
import warnings
from copy import deepcopy

import mmcv
import torch
from mmcv.runner import auto_fp16

from mmedit.core import InceptionV3, psnr, ssim, tensor2img
from ..base import BaseModel
from ..builder import build_backbone, build_loss
from ..registry import MODELS


@MODELS.register_module()
class BasicRestorer(BaseModel):
    """Basic model for image restoration.

    It must contain a generator that takes an image as inputs and outputs a
    restored image. It also has a pixel-wise loss for training.

    The subclasses should overwrite the function `forward_train`,
    `forward_test` and `train_step`.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """
    # Class-level table of per-image metrics, shared by all instances and
    # subclasses. It is treated as read-only: ``evaluate`` copies it onto the
    # instance before registering an instance-specific feature extractor.
    allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}
    # Metric names that are not computed per image here; they only trigger
    # registration of the '_inception_feat' extractor in ``evaluate``.
    feature_based_metrics = ['FID', 'KID']

    def __init__(self,
                 generator,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # support fp16
        self.fp16_enabled = False

        # generator
        self.generator = build_backbone(generator)
        self.init_weights(pretrained)

        # loss
        self.pixel_loss = build_loss(pixel_loss)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments, forwarded to ``forward_test`` only.

        Returns:
            dict: Result of ``forward_test`` when ``test_mode`` is True,
                otherwise result of ``forward_train``.
        """
        if test_mode:
            return self.forward_test(lq, gt, **kwargs)

        return self.forward_train(lq, gt)

    def forward_train(self, lq, gt):
        """Training forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w).

        Returns:
            dict: ``losses`` (dict with ``loss_pix``), ``num_samples`` and
                ``results`` (CPU copies of ``lq``, ``gt`` and ``output``).
        """
        losses = dict()
        output = self.generator(lq)
        loss_pix = self.pixel_loss(output, gt)
        losses['loss_pix'] = loss_pix
        outputs = dict(
            losses=losses,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu()))
        return outputs

    def evaluate(self, output, gt):
        """Evaluation function.

        Per-image metrics listed in ``self.test_cfg.metrics`` are computed
        with ``self.allowed_metrics``. Feature-based metrics (FID/KID), given
        either by name or as a dict with a ``type`` key, are returned as their
        build config; their presence also registers an ``'_inception_feat'``
        extractor, which is then evaluated like any other metric.

        Args:
            output (Tensor): Model output with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w).

        Returns:
            dict: Evaluation results.
        """
        crop_border = self.test_cfg.crop_border

        output = tensor2img(output)
        gt = tensor2img(gt)

        eval_result = dict()
        inception_needed_metrics = []
        for metric in self.test_cfg.metrics:
            if metric in self.feature_based_metrics:
                inception_needed_metrics.append(metric)
                # build with default args
                eval_result[metric] = dict(type=metric)
            elif (isinstance(metric, dict)
                  and metric['type'] in self.feature_based_metrics):
                inception_needed_metrics.append(metric['type'])
                # build with user defined args
                eval_result[metric['type']] = deepcopy(metric)

        if inception_needed_metrics:
            self._prepare_inception_metric(inception_needed_metrics)

        for metric in self.test_cfg.metrics:
            if isinstance(metric, dict) or metric in self.feature_based_metrics:
                # skip FID and KID
                continue
            eval_result[metric] = self.allowed_metrics[metric](output, gt,
                                                               crop_border)
        return eval_result

    def _prepare_inception_metric(self, inception_needed_metrics):
        """Register the '_inception_feat' extractor for this instance.

        Args:
            inception_needed_metrics (list[str]): Names of the feature-based
                metrics that require inception features (used in the warning).
        """
        if 'allowed_metrics' not in vars(self):
            # Copy the class-level table so the extractor created below is
            # owned by this instance. Mutating the shared class attribute
            # would leak it into every other restorer and stop them from
            # appending '_inception_feat' to their own ``test_cfg.metrics``.
            self.allowed_metrics = dict(self.allowed_metrics)

        if '_inception_feat' not in self.allowed_metrics:
            inception_style = self.test_cfg.get('inception_style', 'StyleGAN')
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.allowed_metrics['_inception_feat'] = InceptionV3(
                inception_style, device=device)

        # Decide based on this instance's own config, not on shared state.
        if '_inception_feat' not in self.test_cfg.metrics:
            warnings.warn("'_inception_feat' is newly added to "
                          '`self.test_cfg.metrics` to compute '
                          f'{inception_needed_metrics}.')
            self.test_cfg.metrics = tuple(
                self.test_cfg.metrics) + ('_inception_feat', )

    def forward_test(self,
                     lq,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
            meta (list[dict]): Meta data; ``meta[0]['lq_path']`` names the
                saved image. Required when ``save_image`` is True.
                Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Required when ``save_image``
                is True. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: ``eval_result`` when ``test_cfg.metrics`` is set, otherwise
                CPU copies of ``lq``, ``output`` (and ``gt`` if given).

        Raises:
            ValueError: If ``save_image`` is True but ``meta`` or
                ``save_path`` is missing, or ``iteration`` is neither a
                number nor None.
        """
        output = self.generator(lq)
        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
            assert gt is not None, (
                'evaluation with metrics must have gt images.')
            results = dict(eval_result=self.evaluate(output, gt))
        else:
            results = dict(lq=lq.cpu(), output=output.cpu())
            if gt is not None:
                results['gt'] = gt.cpu()

        # save image
        if save_image:
            if meta is None or save_path is None:
                raise ValueError('`meta` and `save_path` must be provided '
                                 'when `save_image` is True.')
            lq_path = meta[0]['lq_path']
            folder_name = osp.splitext(osp.basename(lq_path))[0]
            if isinstance(iteration, numbers.Number):
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(output), save_path)

        return results

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        Args:
            img (Tensor): Input image.

        Returns:
            Tensor: Output image.
        """
        out = self.generator(img)
        return out

    def train_step(self, data_batch, optimizer):
        """Train step.

        Args:
            data_batch (dict): A batch of data.
            optimizer (dict): Optimizers; the ``'generator'`` entry is used.

        Returns:
            dict: Returned output, with ``log_vars`` added and ``losses``
                removed.
        """
        outputs = self(**data_batch, test_mode=False)
        loss, log_vars = self.parse_losses(outputs.pop('losses'))

        # optimize
        optimizer['generator'].zero_grad()
        loss.backward()
        optimizer['generator'].step()

        outputs.update({'log_vars': log_vars})
        return outputs

    def val_step(self, data_batch, **kwargs):
        """Validation step.

        Args:
            data_batch (dict): A batch of data.
            kwargs (dict): Other arguments for ``val_step``.

        Returns:
            dict: Returned output.
        """
        output = self.forward_test(**data_batch, **kwargs)
        return output