注意
您正在阅读 MMEditing 0.x。MMEditing 0.x 将于 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多信息。
mmedit.models.backbones.encoder_decoders.encoders.gl_encoder 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmedit.models.registry import COMPONENTS
@COMPONENTS.register_module()
class GLEncoder(nn.Module):
    """Encoder used in Global&Local model.

    This implementation follows:
    Globally and locally Consistent Image Completion

    The encoder is a plain stack of six ``ConvModule`` layers registered as
    ``enc1`` ... ``enc6`` and applied in that order. Each layer uses padding
    ``(kernel_size - 1) // 2``. Assuming ``ConvModule`` builds its default
    ``nn.Conv2d`` (dilation 1), the stride-1 layers keep the spatial size.
    The two stride-2 layers (``enc2`` and ``enc4``) each map a size ``s``
    to ``ceil(s / 2)``, so the overall spatial reduction is 4x. The output
    always has 256 channels.

    Args:
        norm_cfg (dict): Config dict to build norm layer. Default: None.
        act_cfg (dict): Config dict for activation layer, "relu" by default.
            The same dict object is passed to every layer.
        in_channels (int): Number of channels of the input tensor.
            Default: 4. NOTE(review): presumably an RGB image concatenated
            with a single-channel mask -- confirm against the caller.
    """

    # Per-layer configuration of enc1 ... enc6. These are kept at class
    # level so that ``__init__`` and ``forward`` always agree on how many
    # layers exist.
    _channels = (64, 128, 128, 256, 256, 256)
    _kernel_sizes = (5, 3, 3, 3, 3, 3)
    _strides = (1, 2, 1, 2, 1, 1)

    def __init__(self,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_channels=4):
        super().__init__()
        layer_cfgs = zip(self._channels, self._kernel_sizes, self._strides)
        for idx, (out_channels, ks, stride) in enumerate(layer_cfgs):
            self.add_module(
                f'enc{idx + 1}',
                ConvModule(
                    in_channels,
                    out_channels,
                    kernel_size=ks,
                    stride=stride,
                    # "Same"-style padding for the odd kernel sizes used
                    # here: stride-1 layers preserve height and width.
                    padding=(ks - 1) // 2,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            # The next layer consumes this layer's output.
            in_channels = out_channels

    def forward(self, x):
        """Forward Function.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, c, h, w), where
                ``c`` equals the ``in_channels`` given at construction.

        Returns:
            torch.Tensor: Output tensor with shape of (n, 256, h', w').
                With ``ConvModule``'s default ``nn.Conv2d``, h' and w'
                equal ``ceil(h / 4)`` and ``ceil(w / 4)``.
        """
        for idx in range(len(self._channels)):
            x = getattr(self, f'enc{idx + 1}')(x)
        return x