Shortcuts

注意

您正在阅读 MMEditing 0.x。MMEditing 0.x 将于 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。

mmedit.models.components.discriminators.gl_disc 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger
from .multi_layer_disc import MultiLayerDiscriminator


@COMPONENTS.register_module()
class GLDiscs(nn.Module):
    """Discriminators in Global&Local.

    This discriminator contains a local discriminator and a global
    discriminator as described in the original paper:
    Globally and locally Consistent Image Completion

    Both sub-discriminators are built as ``MultiLayerDiscriminator``. Their
    predictions are concatenated along dim 1 and passed through a single
    linear layer that produces one output value per sample.

    Args:
        global_disc_cfg (dict): Config dict to build global discriminator.
        local_disc_cfg (dict): Config dict to build local discriminator.
        fc_in_channels (int): Number of input features of the final fusion
            linear layer, i.e. the combined feature width of the global and
            local discriminator outputs. Defaults to 2048.
    """

    def __init__(self, global_disc_cfg, local_disc_cfg, fc_in_channels=2048):
        super().__init__()
        self.global_disc = MultiLayerDiscriminator(**global_disc_cfg)
        self.local_disc = MultiLayerDiscriminator(**local_disc_cfg)
        # Fuses the concatenated global/local predictions into one score.
        self.fc = nn.Linear(fc_in_channels, 1, bias=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (tuple[torch.Tensor]): Contains the global image and the local
                image patch, in that order.

        Returns:
            torch.Tensor: Output of the fusion linear layer applied to the
                concatenation (along dim 1) of the global and local
                discriminator predictions. Its last dimension is 1.
        """
        g_img, l_img = x
        g_pred = self.global_disc(g_img)
        l_pred = self.local_disc(l_img)
        # NOTE(review): assumes both predictions are (N, C) feature vectors
        # whose widths sum to ``fc_in_channels`` -- confirm against the
        # MultiLayerDiscriminator configs.
        pred = self.fc(torch.cat([g_pred, l_pred], dim=1))
        return pred

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            # Non-strict load: missing/unexpected keys are tolerated.
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                # Only linear layers are initialized here, since the conv and
                # norm layers have already been initialized in `ConvModule`.
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight.data, 0.0, 0.02)
                    # A Linear built with bias=False has ``bias`` set to None;
                    # skip it instead of crashing on ``None.data``.
                    if m.bias is not None:
                        nn.init.constant_(m.bias.data, 0.0)
        else:
            raise TypeError('pretrained must be a str or None')