mmedit.apis.inferencers.text2image_inferencer 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List
import numpy as np
from mmengine import mkdir_or_exist
from torchvision.utils import save_image
from .base_mmedit_inferencer import BaseMMEditInferencer, InputsType, PredType
[文档]class Text2ImageInferencer(BaseMMEditInferencer):
"""inferencer that predicts with text2image models."""
[文档] func_kwargs = dict(
preprocess=['text'],
forward=[],
visualize=['result_out_dir'],
postprocess=[])
[文档] extra_parameters = dict(
scheduler_kwargs=None,
height=None,
width=None,
init_image=None,
batch_size=1,
num_inference_steps=1000,
skip_steps=0,
show_progress=False,
text_prompts=[],
image_prompts=[],
eta=0.8,
clip_guidance_scale=5000,
init_scale=1000,
tv_scale=0.,
sat_scale=0.,
range_scale=150,
cut_overview=[12] * 400 + [4] * 600,
cut_innercut=[4] * 400 + [12] * 600,
cut_ic_pow=[1] * 1000,
cut_icgray_p=[0.2] * 400 + [0] * 600,
cutn_batches=4,
seed=2022)
[文档] def preprocess(self, text: InputsType) -> Dict:
"""Process the inputs into a model-feedable format.
Args:
text(InputsType): text input for text-to-image model.
Returns:
result(Dict): Results of preprocess.
"""
result = self.extra_parameters
result['text_prompts'] = text
return result
[文档] def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model."""
image = self.model.infer(**inputs)['samples']
return image
[文档] def visualize(self,
preds: PredType,
result_out_dir: str = None) -> List[np.ndarray]:
"""Visualize predictions.
Args:
preds (List[Union[str, np.ndarray]]): Forward results
by the inferencer.
result_out_dir (str): Output directory of image.
Defaults to ''.
Returns:
List[np.ndarray]: Result of visualize
"""
if result_out_dir:
mkdir_or_exist(os.path.dirname(result_out_dir))
save_image(preds, result_out_dir, normalize=True)
return preds