注意
您正在阅读 MMEditing 0.x。MMEditing 0.x 会在 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。
mmedit.models.backbones.encoder_decoders.encoders.fba_encoder 源代码
# Copyright (c) OpenMMLab. All rights reserved.
from mmedit.models.registry import COMPONENTS
from .resnet import ResNet
@COMPONENTS.register_module()
class FBAResnetDilated(ResNet):
    """ResNet-based encoder for FBA image matting.

    Feeds the first 11 input channels through the ResNet stem and its four
    residual stages, collecting every intermediate feature map. The
    ``merged`` and ``two_channel_trimap`` channels are sliced out of the
    input and returned unchanged alongside the features.
    """

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W). The channels
                are laid out as
                ``(merged_t, trimap_t, two_channel_trimap, merged)``, where
                the ``_t`` suffix marks a transformed input. Channels
                ``0:11`` are fed to the backbone, channels ``9:11`` are the
                two-channel trimap and channels ``11:14`` are the merged
                image.

        Returns:
            dict: A dict with the following keys:

                - ``'conv_out'`` (list[Tensor]): Six tensors, in order: the
                  backbone input (channels ``0:11`` of ``x``), the stem
                  output (taken before max-pooling), and the outputs of
                  ``layer1``, ``layer2``, ``layer3`` and ``layer4``.
                - ``'merged'`` (Tensor): Channels ``11:14`` of ``x``.
                - ``'two_channel_trimap'`` (Tensor): Channels ``9:11`` of
                  ``x``.
        """
        # x: (merged_t, trimap_t, two_channel_trimap, merged)
        # The "_t" suffix refers to transformed inputs.
        two_channel_trimap = x[:, 9:11]
        merged = x[:, 11:14]
        # Only merged_t, trimap_t and two_channel_trimap (channels 0:11)
        # go through the backbone; the trailing raw merged image does not.
        x = x[:, 0:11, ...]

        # Keep the backbone input itself as the first feature entry
        # (presumably consumed by the decoder as a full-resolution skip).
        conv_out = [x]

        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.activate(x)
        # The stem features are recorded before max-pooling.
        conv_out.append(x)
        x = self.maxpool(x)

        # Record the output of every residual stage.
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = layer(x)
            conv_out.append(x)

        return {
            'conv_out': conv_out,
            'merged': merged,
            'two_channel_trimap': two_channel_trimap
        }