Shortcuts

Source code for mmedit.models.editors.dcgan.dcgan_discriminator

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.logging import MMLogger
from mmengine.model import normal_init
from mmengine.runner import load_checkpoint
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmedit.registry import MODELS


@MODELS.register_module()
class DCGANDiscriminator(nn.Module):
    """Discriminator for DCGAN.

    Implementation details for the DCGAN architecture:

    #. Adopt strided convolutions in the discriminator;
    #. Use batchnorm in the discriminator except for the first (input)
       convolution and the final output layer;
    #. Use LeakyReLU in the discriminator except for the output layer, whose
       activation is controlled separately by ``out_act_cfg``.

    Args:
        input_scale (int): The scale of the input image.
        output_scale (int): The final scale of the convolutional feature.
        out_channels (int): The channel number of the final output layer.
        in_channels (int, optional): The channel number of the input image.
            Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            discriminator. The ``i``-th downsampling layer outputs
            ``base_channels * 2**i`` channels. Defaults to 128.
        default_norm_cfg (dict, optional): Norm config for all layers except
            the first convolution and the final output layer.
            Defaults to ``dict(type='BN')``.
        default_act_cfg (dict, optional): Activation config for all layers
            except the final output layer.
            Defaults to ``dict(type='LeakyReLU')``.
        out_act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to ``None``.
        pretrained (str, optional): Path for the pretrained model.
            Defaults to ``None``.
    """

    def __init__(self,
                 input_scale,
                 output_scale,
                 out_channels,
                 in_channels=3,
                 base_channels=128,
                 default_norm_cfg=dict(type='BN'),
                 default_act_cfg=dict(type='LeakyReLU'),
                 out_act_cfg=None,
                 pretrained=None):
        super().__init__()
        self.input_scale = input_scale
        self.output_scale = output_scale
        self.out_channels = out_channels
        self.base_channels = base_channels

        # Each stride-2 conv halves the resolution, so the number of
        # downsampling layers is log2 of the (integer) scale ratio.
        self.num_downsamples = int(np.log2(input_scale // output_scale))

        # Build the downsampling backbone (excluding the output layer).
        # ``curr_channels`` always holds the channel count produced so far,
        # i.e. the input channels of the next layer to be appended.
        downsamples = []
        curr_channels = in_channels
        for i in range(self.num_downsamples):
            # The first conv operates on the raw image and gets no norm.
            norm_cfg_ = None if i == 0 else default_norm_cfg
            next_channels = base_channels * 2**i
            downsamples.append(
                ConvModule(
                    curr_channels,
                    next_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=norm_cfg_,
                    act_cfg=default_act_cfg))
            curr_channels = next_channels

        self.downsamples = nn.Sequential(*downsamples)

        # Output layer: an unpadded 4x4 conv without normalization.
        self.output_layer = ConvModule(
            curr_channels,
            out_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            conv_cfg=dict(type='Conv2d'),
            norm_cfg=None,
            act_cfg=out_act_cfg)

        self.init_weights(pretrained=pretrained)

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Fake or real image tensor.

        Returns:
            torch.Tensor: Prediction for the reality of the input image,
            flattened to shape ``(batch_size, -1)``.
        """
        n = x.shape[0]
        x = self.downsamples(x)
        x = self.output_layer(x)

        # Flatten everything but the batch dimension.
        return x.view(n, -1)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        When ``pretrained`` is a path, the checkpoint is loaded non-strictly.
        When it is ``None``, convolution layers are initialized with
        ``normal_init(m, 0, 0.02)`` and batchnorm layers get normally
        distributed weights and zero biases.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
        """
        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                    normal_init(m, 0, 0.02)
                elif isinstance(m, _BatchNorm):
                    # NOTE(review): ``nn.init.normal_`` is called with its
                    # default mean/std here, whereas common DCGAN reference
                    # code uses N(1, 0.02) for batchnorm weights. Kept as-is
                    # to preserve existing behavior -- confirm intent.
                    nn.init.normal_(m.weight.data)
                    nn.init.constant_(m.bias.data, 0)
        else:
            raise TypeError('pretrained must be a str or None but'
                            f' got {type(pretrained)} instead.')