# Shortcuts

# Source code for mmedit.models.editors.mspie.mspie_stylegan2_discriminator

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn

from mmedit.registry import MODELS
from ..stylegan1 import EqualLinearActModule
from ..stylegan2 import ConvDownLayer, ModMBStddevLayer, ResBlock


@MODELS.register_module()
class MSStyleGAN2Discriminator(nn.Module):
    """StyleGAN2 discriminator with an optional adaptive-pooling head.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    Args:
        in_size (int): The input size of images. It is used as a key into
            the internal channel table, whose keys are 4, 8, 16, 32, 64,
            128, 256, 512 and 1024.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number at resolutions of 64 and above. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel, forwarded to every
            ``ResBlock``. Defaults to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        with_adaptive_pool (bool, optional): Whether to apply
            ``nn.AdaptiveAvgPool2d`` to the output of the final conv layer
            before it is flattened. Defaults to False.
        pool_size (tuple[int], optional): Output size of the adaptive
            pooling layer. It is only used to build the pooling layer and
            to size the first linear layer when ``with_adaptive_pool`` is
            True. Defaults to (2, 2).
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 with_adaptive_pool=False,
                 pool_size=(2, 2)):
        super().__init__()
        self.with_adaptive_pool = with_adaptive_pool
        self.pool_size = pool_size

        # Resolution -> channel width. Low resolutions use a fixed width;
        # from 64 upwards the width scales with ``channel_multiplier``.
        width_at = {res: 512 for res in (4, 8, 16, 32)}
        scaled_widths = ((64, 256), (128, 128), (256, 64), (512, 32),
                         (1024, 16))
        for res, base_width in scaled_widths:
            width_at[res] = base_width * channel_multiplier

        num_levels = int(np.log2(in_size))
        stem_width = width_at[in_size]

        # A 1x1 conv lifts the 3-channel image to the stem width, then one
        # ResBlock is added per resolution from in_size / 2 down to 4.
        stages = [ConvDownLayer(3, stem_width, 1)]
        prev_width = stem_width
        for exponent in reversed(range(2, num_levels)):
            next_width = width_at[2**exponent]
            stages.append(ResBlock(prev_width, next_width, blur_kernel))
            prev_width = next_width
        self.convs = nn.Sequential(*stages)

        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)

        head_width = width_at[4]
        # NOTE(review): the "+ 1" presumably accounts for the extra feature
        # map appended by the minibatch-stddev layer under the default
        # ``channel_groups=1`` -- confirm against ModMBStddevLayer.
        self.final_conv = ConvDownLayer(prev_width + 1, head_width, 3)

        if self.with_adaptive_pool:
            self.adaptive_pool = nn.AdaptiveAvgPool2d(pool_size)
            pooled_h, pooled_w = pool_size[0], pool_size[1]
        else:
            # Without pooling the linear head is sized for a 4x4 feature map.
            pooled_h, pooled_w = 4, 4
        linear_in_channels = head_width * pooled_h * pooled_w

        hidden_fc = EqualLinearActModule(
            linear_in_channels, head_width, act_cfg=dict(type='fused_bias'))
        score_fc = EqualLinearActModule(head_width, 1)
        self.final_linear = nn.Sequential(hidden_fc, score_fc)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        feats = self.final_conv(self.mbstd_layer(self.convs(x)))
        if self.with_adaptive_pool:
            feats = self.adaptive_pool(feats)
        # Flatten everything except the batch dimension for the linear head.
        flattened = feats.view(feats.shape[0], -1)
        return self.final_linear(flattened)