mmedit.apis.inferencers 源代码
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from mmedit.utils import ConfigType
from .colorization_inferencer import ColorizationInferencer
from .conditional_inferencer import ConditionalInferencer
from .controlnet_animation_inferencer import ControlnetAnimationInferencer
from .eg3d_inferencer import EG3DInferencer
from .image_super_resolution_inferencer import ImageSuperResolutionInferencer
# yapf: disable
from .inference_functions import (calculate_grid_size, colorization_inference,
delete_cfg, init_model, inpainting_inference,
matting_inference,
restoration_face_inference,
restoration_inference,
restoration_video_inference,
sample_conditional_model,
sample_img2img_model,
sample_unconditional_model, set_random_seed,
video_interpolation_inference)
# yapf: enable
from .inpainting_inferencer import InpaintingInferencer
from .matting_inferencer import MattingInferencer
from .text2image_inferencer import Text2ImageInferencer
from .translation_inferencer import TranslationInferencer
from .unconditional_inferencer import UnconditionalInferencer
from .video_interpolation_inferencer import VideoInterpolationInferencer
from .video_restoration_inferencer import VideoRestorationInferencer
# Public API of this module: the functional inference helpers come first,
# followed by the task-specific inferencer classes and finally the
# ``Inferencers`` dispatcher class defined in this module, so that
# ``from mmedit.apis.inferencers import *`` also exposes the dispatcher.
__all__ = [
    'init_model', 'delete_cfg', 'set_random_seed', 'matting_inference',
    'inpainting_inference', 'restoration_inference',
    'restoration_video_inference', 'restoration_face_inference',
    'video_interpolation_inference', 'sample_conditional_model',
    'sample_unconditional_model', 'sample_img2img_model',
    'colorization_inference', 'calculate_grid_size', 'ColorizationInferencer',
    'ConditionalInferencer', 'EG3DInferencer', 'InpaintingInferencer',
    'MattingInferencer', 'ImageSuperResolutionInferencer',
    'Text2ImageInferencer', 'TranslationInferencer', 'UnconditionalInferencer',
    'VideoInterpolationInferencer', 'VideoRestorationInferencer',
    'ControlnetAnimationInferencer', 'Inferencers'
]
class Inferencers:
    """Class to assign task to different inferencers.

    The constructor looks up ``task`` among the accepted task names and
    instantiates the matching inferencer, which is stored in
    ``self.inferencer``. Calls to the instance and to
    :meth:`get_extra_parameters` are delegated to that inferencer.

    Accepted task names (either spelling selects the same inferencer):

    - ``'conditional'`` / ``'Conditional GANs'``
    - ``'colorization'`` / ``'Colorization'``
    - ``'unconditional'`` / ``'Unconditional GANs'``
    - ``'matting'`` / ``'Matting'``
    - ``'inpainting'`` / ``'Inpainting'``
    - ``'translation'`` / ``'Image2Image'``
    - ``'Image super-resolution'`` / ``'Image Super-Resolution'``
    - ``'video_restoration'`` / ``'Video Super-Resolution'``
    - ``'video_interpolation'`` / ``'Video Interpolation'``
    - ``'text2image'`` / ``'Text2Image'`` / ``'Text2Image, Image2Image'``
    - ``'3D_aware_generation'`` / ``'3D-aware Generation'``
    - ``'controlnet_animation'``

    Args:
        task (str, optional): Inferencer task. Must be one of the names
            listed above. Defaults to None, which is not a valid task.
        config (str or ConfigType, optional): Model config or the path to
            it. Defaults to None.
        ckpt (str, optional): Path to the checkpoint. Defaults to None.
        device (str or torch.device, optional): Device to run inference.
            If None, the best device will be automatically used.
            Defaults to None.
        extra_parameters (dict, optional): Extra parameters handed to the
            selected inferencer's constructor. Defaults to None.
        seed (int): The random seed used in inference. Defaults to 2022.

    Raises:
        ValueError: If ``task`` does not match any accepted task name.

    Note:
        ``seed`` is not forwarded for the video interpolation task, and
        only ``config`` is forwarded for ``'controlnet_animation'``.
    """

    def __init__(self,
                 task: Optional[str] = None,
                 config: Optional[Union[ConfigType, str]] = None,
                 ckpt: Optional[str] = None,
                 device: Optional[Union[str, torch.device]] = None,
                 extra_parameters: Optional[Dict] = None,
                 seed: int = 2022) -> None:
        self.task = task

        # Constructor arguments shared by almost every inferencer.
        common_args = (config, ckpt, device, extra_parameters)
        seeded = {'seed': seed}

        # One row per task: (accepted names, inferencer class, positional
        # arguments, keyword arguments). Built inside ``__init__`` so the
        # class names are resolved from module globals at call time.
        dispatch_table = (
            (('conditional', 'Conditional GANs'), ConditionalInferencer,
             common_args, seeded),
            (('colorization', 'Colorization'), ColorizationInferencer,
             common_args, seeded),
            (('unconditional', 'Unconditional GANs'),
             UnconditionalInferencer, common_args, seeded),
            (('matting', 'Matting'), MattingInferencer, common_args,
             seeded),
            (('inpainting', 'Inpainting'), InpaintingInferencer,
             common_args, seeded),
            (('translation', 'Image2Image'), TranslationInferencer,
             common_args, seeded),
            (('Image super-resolution', 'Image Super-Resolution'),
             ImageSuperResolutionInferencer, common_args, seeded),
            (('video_restoration', 'Video Super-Resolution'),
             VideoRestorationInferencer, common_args, seeded),
            # NOTE(review): unlike its siblings this inferencer is built
            # without ``seed`` -- confirm whether that is intentional.
            (('video_interpolation', 'Video Interpolation'),
             VideoInterpolationInferencer, common_args, {}),
            (('text2image', 'Text2Image', 'Text2Image, Image2Image'),
             Text2ImageInferencer, common_args, seeded),
            (('3D_aware_generation', '3D-aware Generation'), EG3DInferencer,
             common_args, seeded),
            # Only the config is handed to the ControlNet animation
            # inferencer; ckpt, device, extra_parameters and seed are not.
            (('controlnet_animation', ), ControlnetAnimationInferencer,
             (config, ), {}),
        )

        for task_names, inferencer_cls, args, kwargs in dispatch_table:
            if self.task in task_names:
                self.inferencer = inferencer_cls(*args, **kwargs)
                break
        else:
            # No row matched the requested task.
            raise ValueError(f'Unknown inferencer task: {self.task}')

    def __call__(self, **kwargs) -> Union[Dict, List[Dict]]:
        """Call the inferencer.

        Args:
            kwargs: Keyword arguments for the inferencer.

        Returns:
            Union[Dict, List[Dict]]: Results of inference pipeline.
        """
        return self.inferencer(**kwargs)

    def get_extra_parameters(self) -> List[str]:
        """Each inferencer may have its own parameters. Call this function
        to get these parameters.

        Returns:
            List[str]: List of unique parameters.
        """
        return self.inferencer.get_extra_parameters()