mmedit.models.editors.gca.gca
Module Contents
Classes
- GCA – Guided Contextual Attention image matting model.
- class mmedit.models.editors.gca.gca.GCA(data_preprocessor, backbone, loss_alpha=None, init_cfg: Optional[dict] = None, train_cfg=None, test_cfg=None)[source]
Bases: mmedit.models.base_models.BaseMattor
Guided Contextual Attention image matting model.
https://arxiv.org/abs/2001.04069
- Parameters
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
backbone (dict) – Config of the backbone.
loss_alpha (dict) – Config of the alpha prediction loss. Default: None.
init_cfg (dict, optional) – Initialization config dict. Default: None.
train_cfg (dict) – Config of training. In train_cfg, train_backbone should be specified. If the model has a refiner, train_refiner should also be specified.
test_cfg (dict) – Config of testing. In test_cfg, if the model has a refiner, train_refiner should be specified.
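For orientation, a model config laid out after the constructor signature above might look like the sketch below. This is not a shipped config: the `type` strings for `data_preprocessor` and `backbone` are hypothetical placeholders, and `L1Loss` as the alpha loss is an assumption. Only `train_backbone` inside `train_cfg` comes from the parameter list above; check the GCA configs distributed with MMEditing for real values.

```python
# Hypothetical GCA model config following the constructor signature above.
# ASSUMPTION: the data_preprocessor/backbone 'type' strings are placeholders,
# not real registry names; 'L1Loss' for the alpha loss is also assumed.
model = dict(
    type='GCA',
    data_preprocessor=dict(type='HypotheticalMattorPreprocessor'),
    backbone=dict(type='HypotheticalEncoderDecoder'),
    loss_alpha=dict(type='L1Loss'),
    init_cfg=None,
    # The parameter docs ask for `train_backbone` in train_cfg.
    train_cfg=dict(train_backbone=True),
    # GCA has no refiner here, so test_cfg is left empty in this sketch.
    test_cfg=dict(),
)
```

Each top-level key except `type` maps one-to-one onto a constructor argument, which keeps the config easy to audit against the signature.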
- _forward(inputs)[source]
Forward function.
- Parameters
inputs (torch.Tensor) – Input tensor.
- Returns
Output tensor.
- Return type
Tensor
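This page does not say how the output tensor is obtained from the backbone. A common pattern for alpha-matting heads, and our assumption for GCA here, is to squash the backbone's unbounded output into the [0, 1] alpha range with (tanh(x) + 1) / 2, which equals sigmoid(2x). The pure-Python sketch below applies that mapping element-wise to a flat list standing in for a tensor; `raw_to_alpha` is a hypothetical helper, not an mmedit function.

```python
import math
from typing import List


def raw_to_alpha(raw: List[float]) -> List[float]:
    """Map unbounded backbone outputs to alpha values in [0, 1].

    Hypothetical helper. ASSUMPTION: (tanh(x) + 1) / 2 squashing, which is
    identical to sigmoid(2x). Works on a flat list instead of a torch.Tensor.
    """
    return [(math.tanh(x) + 1.0) / 2.0 for x in raw]


alpha = raw_to_alpha([-10.0, 0.0, 10.0])
print([round(a, 4) for a in alpha])  # [0.0, 0.5, 1.0]
```

Using tanh rather than a hard clamp keeps the mapping differentiable everywhere, so gradients still flow for pixels whose prediction overshoots 0 or 1.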
- _forward_test(inputs)[source]
Forward function for testing the GCA model.
- Parameters
inputs (torch.Tensor) – Batch input tensor.
- Returns
Output tensor of the model.
- Return type
Tensor
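At test time, encoder-decoder matting networks usually need spatial sizes that downsample cleanly. Assuming (this page does not confirm it) that test_cfg asks for inputs to be padded to a multiple of some divisor such as 32, the sketch below computes the bottom/right padding before the forward pass and crops the prediction back to the original size afterwards. `pad_amounts` and `crop_to` are hypothetical helpers operating on plain Python lists.

```python
from typing import List, Tuple


def pad_amounts(height: int, width: int, divisor: int = 32) -> Tuple[int, int]:
    """Return (pad_bottom, pad_right) so both sides become multiples of divisor.

    ASSUMPTION: padding is applied only at the bottom and right edges.
    """
    pad_h = (divisor - height % divisor) % divisor
    pad_w = (divisor - width % divisor) % divisor
    return pad_h, pad_w


def crop_to(alpha: List[List[float]], height: int, width: int) -> List[List[float]]:
    """Drop padded rows/columns so the prediction matches the original size."""
    return [row[:width] for row in alpha[:height]]


print(pad_amounts(500, 320))  # (12, 0)
```

Padding only at the bottom and right keeps the original pixels at the same coordinates, so restoring the size is a plain crop.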
- _forward_train(inputs, data_samples)[source]
Forward function for training the GCA model.
- Parameters
inputs (torch.Tensor) – Batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement]) – Data samples collated by data_preprocessor.
- Returns
A dict containing the loss items and batch information.
- Return type
dict
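The training loss is built from the loss_alpha config. In trimap-based matting the alpha loss is typically restricted to the trimap's unknown region; assuming that holds for GCA, and assuming a trimap encoded as 0 (background), 128 (unknown) and 255 (foreground), a weighted L1 loss looks like the sketch below. `unknown_weight`, `weighted_l1` and `forward_train_losses` are hypothetical helpers on flat lists, not mmedit functions, and only the loss entry of the returned dict is modelled.

```python
from typing import Dict, List


def unknown_weight(trimap: List[int], unknown_value: int = 128) -> List[float]:
    """1.0 where the trimap marks the pixel as unknown, else 0.0.

    ASSUMPTION: unknown pixels are encoded as 128 in the trimap.
    """
    return [1.0 if t == unknown_value else 0.0 for t in trimap]


def weighted_l1(pred: List[float], gt: List[float], weight: List[float],
                eps: float = 1e-12) -> float:
    """Mean absolute error over pixels with non-zero weight."""
    total = sum(w * abs(p - g) for p, g, w in zip(pred, gt, weight))
    return total / (sum(weight) + eps)  # eps guards against an all-known trimap


def forward_train_losses(pred: List[float], gt: List[float],
                         trimap: List[int]) -> Dict[str, float]:
    """Package the loss in a dict, mirroring the documented return type."""
    return {'loss': weighted_l1(pred, gt, unknown_weight(trimap))}


trimap = [0, 128, 128, 255]
pred = [0.2, 0.5, 0.9, 0.7]
gt = [0.0, 0.4, 1.0, 1.0]
# Roughly 0.1: only the two unknown pixels contribute to the average.
print(forward_train_losses(pred, gt, trimap))
```

Known foreground and background pixels are already fixed by the trimap, so excluding them focuses the gradient on the region the network actually has to resolve.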