注意
您正在阅读 MMEditing 0.x。MMEditing 0.x 会在 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。
mmedit.models.backbones.encoder_decoders.simple_encoder_decoder 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES
@BACKBONES.register_module()
class SimpleEncoderDecoder(nn.Module):
    """Simple encoder-decoder model from matting.

    The encoder and decoder are each built from a config via
    ``build_component``. If the built encoder has an ``out_channels``
    attribute, its value is injected into the decoder config as
    ``in_channels`` before the decoder is built.

    Args:
        encoder (dict): Config of the encoder.
        decoder (dict): Config of the decoder. The caller's dict is never
            modified; ``in_channels`` is injected into a shallow copy.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = build_component(encoder)
        if hasattr(self.encoder, 'out_channels'):
            # Build a new dict instead of assigning into ``decoder`` so a
            # config that is shared or reused by the caller is left untouched.
            decoder = {**decoder, 'in_channels': self.encoder.out_channels}
        self.decoder = build_component(decoder)

    def init_weights(self, pretrained=None):
        """Initialize the weights of the encoder and decoder.

        Args:
            pretrained (optional): Forwarded unchanged to the encoder's
                ``init_weights``. The decoder's ``init_weights`` is called
                with no arguments. Presumably a checkpoint path or None;
                confirm against the concrete encoder. Default: None.
        """
        self.encoder.init_weights(pretrained)
        self.decoder.init_weights()

    def forward(self, *args, **kwargs):
        """Forward function.

        All positional and keyword arguments go to the encoder. The
        encoder's output is then passed as the only argument to the decoder.

        Returns:
            Tensor: The output tensor of the decoder.
        """
        out = self.encoder(*args, **kwargs)
        return self.decoder(out)