Shortcuts

注意

您正在阅读 MMEditing 0.x。MMEditing 0.x 会在 2022 年末开始逐步停止维护,建议您及时升级到 MMEditing 1.0 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。

mmedit.models.backbones.encoder_decoders.simple_encoder_decoder 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES


@BACKBONES.register_module()
class SimpleEncoderDecoder(nn.Module):
    """Simple encoder-decoder model from matting.

    Both sub-modules are created from config dicts with ``build_component``.
    The decoder is fed whatever the encoder returns.

    Args:
        encoder (dict): Config of the encoder.
        decoder (dict): Config of the decoder. If the built encoder has an
            ``out_channels`` attribute, that value is used as the decoder's
            ``in_channels`` and overrides any value given in this config.
            The dict passed by the caller is not modified.
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = build_component(encoder)

        if hasattr(self.encoder, 'out_channels'):
            # Write the inferred channel count into a shallow copy so the
            # caller's config object is not changed as a side effect of
            # constructing the model.
            decoder = decoder.copy()
            decoder['in_channels'] = self.encoder.out_channels
        self.decoder = build_component(decoder)

    def init_weights(self, pretrained=None):
        """Initialize the weights of the encoder and the decoder.

        Args:
            pretrained (optional): Forwarded unchanged to
                ``self.encoder.init_weights``. The decoder's
                ``init_weights`` is always called without arguments.
                Defaults to None.
        """
        self.encoder.init_weights(pretrained)
        self.decoder.init_weights()

    def forward(self, *args, **kwargs):
        """Forward function.

        All positional and keyword arguments are passed to the encoder; the
        encoder's output is then passed as the single argument to the
        decoder.

        Returns:
            Tensor: The output tensor of the decoder.
        """
        out = self.encoder(*args, **kwargs)
        out = self.decoder(out)
        return out