注意
您正在阅读 MMEditing 0.x 的文档。MMEditing 0.x 将于 2022 年末起逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，以享受 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。
mmedit.models.components.discriminators.gl_disc 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger
from .multi_layer_disc import MultiLayerDiscriminator
@COMPONENTS.register_module()
class GLDiscs(nn.Module):
    """Discriminators in Global&Local.

    This discriminator contains a local discriminator and a global
    discriminator as described in the original paper:
    Globally and locally Consistent Image Completion

    The outputs of the two discriminators are concatenated along dim 1 and
    passed through a single linear layer (``self.fc``) with one output
    feature, which yields the final prediction.

    Args:
        global_disc_cfg (dict): Config dict to build global discriminator.
        local_disc_cfg (dict): Config dict to build local discriminator.
        fc_in_channels (int): Number of input features of the final linear
            layer, i.e. the combined feature size of the concatenated
            global and local discriminator outputs. Defaults to 2048.
    """

    def __init__(self, global_disc_cfg, local_disc_cfg, fc_in_channels=2048):
        super().__init__()
        self.global_disc = MultiLayerDiscriminator(**global_disc_cfg)
        self.local_disc = MultiLayerDiscriminator(**local_disc_cfg)
        # Attribute name ``fc`` is kept unchanged so existing checkpoints
        # continue to load with matching parameter keys.
        self.fc = nn.Linear(fc_in_channels, 1, bias=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (tuple[torch.Tensor]): Contains global image and the local image
                patch.

        Returns:
            torch.Tensor: Output of the final linear layer, computed from the
                global and local discriminator outputs concatenated along
                dim 1.
        """
        g_img, l_img = x
        g_pred = self.global_disc(g_img)
        l_pred = self.local_disc(l_img)
        # NOTE(review): assumes both discriminators return 2-D (N, C)
        # features whose C values sum to ``fc_in_channels`` - confirm
        # against the discriminator configs.
        pred = self.fc(torch.cat([g_pred, l_pred], dim=1))
        return pred

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                # Only linear layers are initialized here, since the conv and
                # norm layers have been initialized in `ConvModule`. Note that
                # this also visits linear layers inside the sub-discriminators.
                if isinstance(m, nn.Linear):
                    # ``nn.init`` functions already run under ``no_grad``, so
                    # operating on the parameters directly is safe.
                    nn.init.normal_(m.weight, 0.0, 0.02)
                    # A linear layer built with ``bias=False`` has
                    # ``bias is None``; skip it instead of crashing.
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0.0)
        else:
            raise TypeError('pretrained must be a str or None')