mmedit.models.base_models.base_edit_model
Module Contents
Classes
BaseEditModel: Base model for image and video editing.
- class mmedit.models.base_models.base_edit_model.BaseEditModel(generator: dict, pixel_loss: dict, train_cfg: Optional[dict] = None, test_cfg: Optional[dict] = None, init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None)
Bases: mmengine.model.BaseModel
Base model for image and video editing.
It must contain a generator that takes frames as inputs and outputs the edited frames (for example, restored or super-resolved frames). It also has a pixel-wise loss for training.
- Parameters
generator (dict) – Config for the generator structure.
pixel_loss (dict) – Config for the pixel-wise loss.
train_cfg (dict, optional) – Config for training. Default: None.
test_cfg (dict, optional) – Config for testing. Default: None.
init_cfg (dict, optional) – The weight initialization config for BaseModule. Default: None.
data_preprocessor (dict, optional) – The pre-processing config of BaseDataPreprocessor. Default: None.
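The parameters above are usually supplied as nested config dicts. The example below is a sketch modeled on the SRCNN super-resolution config in mmedit 1.x. Treat the registered type names (`SRCNNNet`, `L1Loss`, `EditDataPreprocessor`) and the generator arguments as assumptions to verify against your installed version. `scale` is only a local helper variable.

```python
# Sketch of a BaseEditModel config, modeled on mmedit's SRCNN config.
# Registered names and generator arguments are assumptions to verify.
scale = 4  # local helper: super-resolution upscale factor

model = dict(
    type='BaseEditModel',
    generator=dict(
        type='SRCNNNet',             # assumed generator registered in MODELS
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=scale),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    train_cfg=dict(),
    test_cfg=dict(metrics=['PSNR'], crop_border=scale),
    data_preprocessor=dict(
        type='EditDataPreprocessor',  # assumed preprocessor type
        mean=[0., 0., 0.],
        std=[255., 255., 255.]))
```

A dict like this is normally built through the model registry. The registry passes the nested `generator`, `pixel_loss` and `data_preprocessor` dicts to the constructor shown above instead of requiring you to instantiate them by hand.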
- init_cfg
Initialization config dict.
- Type
dict, optional
- data_preprocessor
Used for pre-processing data sampled by the dataloader into the format accepted by forward(). Default: None.
- Type
BaseDataPreprocessor
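The sketch below shows the kind of mapping a data preprocessor applies. It assumes mean/std normalization like the `mean` and `std` entries commonly passed to mmedit's `EditDataPreprocessor`. The single-channel lists and both function names are hypothetical. The inverse `destruct` mirrors the "destructed inputs" mentioned under `convert_to_datasample`.

```python
from typing import List, Sequence


def normalize(pixels: Sequence[float], mean: float, std: float) -> List[float]:
    """Pre-process: map raw pixel values into the range the generator expects."""
    return [(p - mean) / std for p in pixels]


def destruct(values: Sequence[float], mean: float, std: float) -> List[float]:
    """Inverse step (hypothetical name): map normalized values back to pixels."""
    return [v * std + mean for v in values]
```

With `mean=0.0` and `std=255.0`, `normalize([0, 51, 255], 0.0, 255.0)` gives `[0.0, 0.2, 1.0]`, and `destruct` maps those values back to the original pixel range. Real preprocessors apply this per channel on batched tensors and may also pad or stack inputs.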
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmedit.structures.EditDataSample]] = None, mode: str = 'tensor', **kwargs) → Union[torch.Tensor, List[mmedit.structures.EditDataSample], dict]
Returns losses or predictions of training, validation, testing, and simple inference process.
The forward method of BaseModel is an abstract method, so its subclasses must implement it.
It accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step.
During distributed data-parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
mode (str) – Must be one of loss, predict and tensor. Default: 'tensor'.
loss: Called by train_step; returns a loss dict used for logging.
predict: Called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.
tensor: Called by custom code to get Tensor-type results.
- Returns
If mode == 'loss', returns a dict of loss tensors used for backward and logging.
If mode == 'predict', returns a list of BaseDataElement for computing metrics and getting inference results.
If mode == 'tensor', returns a tensor, a tuple of tensors, or a dict of tensors for custom use.
- Return type
ForwardResults
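The mode-based dispatch described above can be sketched in plain Python. `ToyEditModel` is hypothetical and is not mmedit's implementation. Lists of floats stand in for tensors, and data samples are plain dicts. The routing of `tensor`, `predict` and `loss` to `forward_tensor`, `forward_inference` and `forward_train` is an assumption based on the methods documented on this page.

```python
from typing import Dict, List, Optional


class ToyEditModel:
    """Hypothetical stand-in that mimics how forward() dispatches on `mode`.

    Plain lists of floats replace torch.Tensor, and each data sample is a
    dict holding one ground-truth value under 'gt'; both are simplifications.
    """

    def forward_tensor(self, inputs: List[float],
                       data_samples: Optional[List[dict]] = None) -> List[float]:
        # Toy "generator": scale every input value by 2.
        return [2.0 * x for x in inputs]

    def forward_inference(self, inputs: List[float],
                          data_samples: Optional[List[dict]] = None) -> List[dict]:
        # Wrap each raw prediction in a record that a metric could consume.
        return [{'pred': p} for p in self.forward_tensor(inputs)]

    def forward_train(self, inputs: List[float],
                      data_samples: List[dict]) -> Dict[str, float]:
        # Mean absolute error between predictions and ground truth.
        preds = self.forward_tensor(inputs)
        gts = [sample['gt'] for sample in data_samples]
        loss = sum(abs(p - g) for p, g in zip(preds, gts)) / len(preds)
        return {'loss': loss}

    def forward(self, inputs: List[float],
                data_samples: Optional[List[dict]] = None,
                mode: str = 'tensor'):
        # Route to one helper per mode (the mapping is an assumption).
        if mode == 'tensor':
            return self.forward_tensor(inputs, data_samples)
        if mode == 'predict':
            return self.forward_inference(inputs, data_samples)
        if mode == 'loss':
            return self.forward_train(inputs, data_samples)
        raise ValueError(
            f"mode must be 'loss', 'predict' or 'tensor', got {mode!r}")
```

For example, `ToyEditModel().forward([1.0, 2.0], [{'gt': 2.0}, {'gt': 5.0}], mode='loss')` returns `{'loss': 0.5}`, while the default `mode='tensor'` returns the raw `[2.0, 4.0]`.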
- convert_to_datasample(predictions: mmedit.structures.EditDataSample, data_samples: mmedit.structures.EditDataSample, inputs: Optional[torch.Tensor]) → List[mmedit.structures.EditDataSample]
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
predictions (EditDataSample) – The predictions of the model.
data_samples (EditDataSample) – The data samples loaded from dataloader.
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
- Returns
Modified data samples.
- Return type
List[EditDataSample]
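The bookkeeping done by convert_to_datasample is sketched below with plain lists. `ToySample` is a hypothetical stand-in for EditDataSample, and the field names `output` and `input` are assumptions. The documented method works on batched data samples and destructs the inputs first. This sketch stores the inputs as they are.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ToySample:
    """Hypothetical stand-in for EditDataSample: a bag of named fields."""
    fields: Dict[str, Any] = field(default_factory=dict)


def convert_to_datasample(predictions: List[Any],
                          data_samples: List[ToySample],
                          inputs: Optional[List[Any]] = None) -> List[ToySample]:
    """Attach each prediction, and optionally its input, to its data sample."""
    if len(predictions) != len(data_samples):
        raise ValueError('expected one prediction per data sample')
    if inputs is not None and len(inputs) != len(data_samples):
        raise ValueError('expected one input per data sample')
    for i, (pred, sample) in enumerate(zip(predictions, data_samples)):
        sample.fields['output'] = pred  # field name 'output' is an assumption
        if inputs is not None:
            # The documented method destructs inputs first; stored as-is here.
            sample.fields['input'] = inputs[i]
    return data_samples
```

The samples are modified in place and returned, which matches the "Modified data samples" return value described above.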
- forward_tensor(inputs: torch.Tensor, data_samples: Optional[List[mmedit.structures.EditDataSample]] = None, **kwargs) → torch.Tensor
Forward tensor. Returns the result of a simple forward pass.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Result of the simple forward pass.
- Return type
Tensor
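A "simple forward" means the generator's raw output is returned without any post-processing. In practice the generator is a neural network. Here a hypothetical nearest-neighbour upsampler (`nearest_upsample`) stands in, on single-channel images stored as nested lists, so the pass has something concrete to run.

```python
from typing import List

Image = List[List[float]]  # rows of single-channel pixel values


def nearest_upsample(img: Image, scale: int) -> Image:
    """Hypothetical toy generator: repeat each pixel `scale` times per axis."""
    out: Image = []
    for row in img:
        wide = [p for p in row for _ in range(scale)]
        out.extend(list(wide) for _ in range(scale))
    return out


def forward_tensor(batch: List[Image], scale: int = 2) -> List[Image]:
    """Simple forward: apply the generator to each image, no post-processing."""
    return [nearest_upsample(img, scale) for img in batch]
```

Calling `forward_tensor([[[1.0, 2.0]]])` turns the 1×2 image into a 2×4 one, `[[[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0]]]`.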
- forward_inference(inputs: torch.Tensor, data_samples: Optional[List[mmedit.structures.EditDataSample]] = None, **kwargs) → mmedit.structures.EditDataSample
Forward inference. Returns predictions of validation, testing, and simple inference.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
EditDataSample
- forward_train(inputs: torch.Tensor, data_samples: Optional[List[mmedit.structures.EditDataSample]] = None, **kwargs) → Dict[str, torch.Tensor]
Forward training. Returns dict of losses of training.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – Data samples collated by data_preprocessor.
- Returns
Dict of losses.
- Return type
dict
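The training forward pass runs the generator and scores its output with the configured pixel-wise loss. The sketch below does this in plain Python with a weighted L1 loss. The `loss_weight` and `reduction` arguments mirror a `pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean')` style config. The function names, the `'loss'` key, and the list-based images are assumptions for illustration.

```python
from typing import Callable, Dict, List

Image = List[List[float]]  # rows of single-channel pixel values


def l1_pixel_loss(pred: Image, gt: Image, loss_weight: float = 1.0,
                  reduction: str = 'mean') -> float:
    """Weighted L1 loss over all pixels of one prediction/ground-truth pair."""
    if len(pred) != len(gt) or any(len(a) != len(b) for a, b in zip(pred, gt)):
        raise ValueError('prediction and ground truth must have the same shape')
    diffs = [abs(p - g) for pr, gr in zip(pred, gt) for p, g in zip(pr, gr)]
    if reduction == 'mean':
        value = sum(diffs) / len(diffs)
    elif reduction == 'sum':
        value = sum(diffs)
    else:
        raise ValueError(f'unsupported reduction {reduction!r}')
    return loss_weight * value


def forward_train(generator: Callable[[Image], Image], inputs: List[Image],
                  gts: List[Image], loss_weight: float = 1.0) -> Dict[str, float]:
    """Run the generator, then average the per-image pixel loss over the batch."""
    preds = [generator(img) for img in inputs]
    losses = [l1_pixel_loss(p, g, loss_weight) for p, g in zip(preds, gts)]
    return {'loss': sum(losses) / len(losses)}
```

With an identity generator, the input `[[[0.0, 2.0], [4.0, 6.0]]]` and the ground truth `[[[1.0, 2.0], [3.0, 6.0]]]` differ by 1 in two of four pixels, so `forward_train` returns `{'loss': 0.5}`.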