# Source code for mmedit.models.editors.pggan.pggan_discriminator
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmedit.registry import MODELS
from .pggan_modules import (EqualizedLRConvDownModule, EqualizedLRConvModule,
MiniBatchStddevLayer, PGGANDecisionHead)
@MODELS.register_module()
class PGGANDiscriminator(nn.Module):
    """Discriminator for PGGAN.

    The network keeps one from-RGB layer per supported resolution and a
    stack of conv blocks. ``forward`` picks the from-RGB layer matching the
    current scale, runs the conv blocks from that scale downwards and blends
    in a downsampled copy of the input during a resolution transition.

    Args:
        in_scale (int): The scale of the input image.
        label_size (int, optional): Size of the label vector. Defaults to
            0.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contains channels based on this
            number. Defaults to 8192.
        max_channels (int, optional): Maximum channels for the feature
            maps in the discriminator block. Defaults to 512.
        in_channels (int, optional): Number of channels in input images.
            Defaults to 3.
        channel_decay (float, optional): Decay for channels of feature
            maps. Defaults to 1.0.
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            If None, no minibatch-stddev layer is used. Defaults to
            dict(group_size=4).
        fused_convdown (bool, optional): Whether use fused downconv.
            Defaults to True.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this discriminator. It is merged on top of
            ``_default_conv_module_cfg``. Defaults to None.
        fused_convdown_cfg (dict, optional): Config for the fused downconv
            module used in this discriminator. It is merged on top of
            ``_default_convdown_cfg``. Defaults to None.
        fromrgb_layer_cfg (dict, optional): Config for the fromrgb layer.
            It is merged on top of ``_default_fromrgb_cfg``. Defaults to
            None.
        downsample_cfg (dict, optional): Config for the downsampling
            operation. ``None`` or ``type='avgpool'`` selects a 2x average
            pooling; ``type`` in ``['nearest', 'bilinear']`` selects
            ``F.interpolate`` with the remaining keys as keyword arguments.
            Defaults to None.

    Raises:
        NotImplementedError: If ``downsample_cfg`` has an unsupported
            ``type``.
    """

    # Base config of the from-RGB layers; user overrides are merged on top.
    _default_fromrgb_cfg = dict(
        conv_cfg=None,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=None,
        order=('conv', 'act', 'norm'))

    # Base config of the stride-1 conv modules.
    _default_conv_module_cfg = dict(
        kernel_size=3,
        padding=1,
        stride=1,
        norm_cfg=None,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2))

    # Base config of the stride-2 (fused) downsampling conv modules.
    _default_convdown_cfg = dict(
        kernel_size=3,
        padding=1,
        stride=2,
        norm_cfg=None,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2))

    def __init__(self,
                 in_scale,
                 label_size=0,
                 base_channels=8192,
                 max_channels=512,
                 in_channels=3,
                 channel_decay=1.0,
                 mbstd_cfg=dict(group_size=4),
                 fused_convdown=True,
                 conv_module_cfg=None,
                 fused_convdown_cfg=None,
                 fromrgb_layer_cfg=None,
                 downsample_cfg=None):
        super().__init__()
        self.in_scale = in_scale
        self.in_log2_scale = int(np.log2(self.in_scale))
        self.label_size = label_size
        self.base_channels = base_channels
        self.max_channels = max_channels
        self.in_channels = in_channels
        self.channel_decay = channel_decay
        self.with_mbstd = mbstd_cfg is not None
        self.fused_convdown = fused_convdown

        # Every user config is merged into a private copy of the class-level
        # default so that the shared defaults are never mutated.
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        if conv_module_cfg is not None:
            self.conv_module_cfg.update(conv_module_cfg)

        if self.fused_convdown:
            self.fused_convdown_cfg = deepcopy(self._default_convdown_cfg)
            if fused_convdown_cfg is not None:
                self.fused_convdown_cfg.update(fused_convdown_cfg)

        self.fromrgb_layer_cfg = deepcopy(self._default_fromrgb_cfg)
        if fromrgb_layer_cfg:
            self.fromrgb_layer_cfg.update(fromrgb_layer_cfg)

        # setup downsample layer
        # This has to happen BEFORE the conv blocks are built:
        # ``_get_convdown_block`` reads ``self.downsample`` when
        # ``fused_convdown`` is False, and reading it before assignment
        # raises ``AttributeError``.
        self.downsample_cfg = deepcopy(downsample_cfg)
        if self.downsample_cfg is None or self.downsample_cfg.get(
                'type', None) == 'avgpool':
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        elif self.downsample_cfg.get('type', None) in ['nearest', 'bilinear']:
            # 'type' is popped from the private copy so that the remaining
            # keys can be forwarded verbatim to ``F.interpolate``.
            self.downsample = partial(
                F.interpolate,
                mode=self.downsample_cfg.pop('type'),
                **self.downsample_cfg)
        else:
            raise NotImplementedError(
                'We have not supported the downsampling with type'
                f' {downsample_cfg}.')

        # setup conv blocks
        # One from-RGB layer is created per log2 scale in
        # [2, in_log2_scale]. ``conv_blocks`` is a flat list: log2 scale 2
        # contributes one module (index 0) and every larger log2 scale ``s``
        # contributes two modules (indices ``2 * s - 5`` and ``2 * s - 4``).
        self.conv_blocks = nn.ModuleList()
        self.fromrgb_layers = nn.ModuleList()
        for s in range(2, self.in_log2_scale + 1):
            self.fromrgb_layers.append(
                self._get_fromrgb_layer(self.in_channels, s))
            self.conv_blocks.extend(
                self._get_convdown_block(self._num_out_channels(s - 1), s))

        # setup minibatch stddev layer
        if self.with_mbstd:
            self.mbstd_layer = MiniBatchStddevLayer(**mbstd_cfg)
            # minibatch stddev layer will concatenate an additional feature
            # map in channel dimension.
            # NOTE(review): the factor 16 presumably is the flattened 4x4
            # spatial size fed to the decision head -- confirm against
            # PGGANDecisionHead.
            decision_in_channels = self._num_out_channels(1) * 16 + 16
        else:
            decision_in_channels = self._num_out_channels(1) * 16

        # setup decision layer
        # The head outputs one score plus ``label_size`` label logits.
        self.decision = PGGANDecisionHead(decision_in_channels,
                                          self._num_out_channels(0),
                                          1 + self.label_size)

    def _num_out_channels(self, log_scale: int) -> int:
        """Calculate the number of output channels of the current network from
        logarithm of current scale.

        The value is ``base_channels / 2 ** (log_scale * channel_decay)``
        truncated to an integer and capped at ``max_channels``.

        Args:
            log_scale (int): The logarithm of the current scale.

        Returns:
            int: The number of output channels.
        """
        return min(
            int(self.base_channels / (2.0**(log_scale * self.channel_decay))),
            self.max_channels)

    def _get_fromrgb_layer(self, in_channels: int,
                           log2_scale: int) -> nn.Module:
        """Get the 'fromrgb' layer from logarithm of current scale.

        Args:
            in_channels (int): The number of input channels.
            log2_scale (int): The logarithm of the current scale.

        Returns:
            nn.Module: The built from-rgb layer.
        """
        return EqualizedLRConvModule(in_channels,
                                     self._num_out_channels(log2_scale - 1),
                                     **self.fromrgb_layer_cfg)

    def _get_convdown_block(self, in_channels: int, log2_scale: int) -> list:
        """Get the conv modules of the block for the given log2 scale.

        Args:
            in_channels (int): The number of input channels.
            log2_scale (int): The logarithm of the current scale.

        Returns:
            list[nn.Module]: A single stride-1 conv module when
            ``log2_scale == 2``; otherwise that conv module followed by a
            downsampling conv module.
        """
        mid_channels = self._num_out_channels(log2_scale - 1)

        # Every block starts with the same stride-1 conv.
        modules = [
            EqualizedLRConvModule(in_channels, mid_channels,
                                  **self.conv_module_cfg)
        ]
        # The lowest-resolution block has no downsampling conv.
        if log2_scale == 2:
            return modules

        if self.fused_convdown:
            cfg_ = dict(downsample=dict(type='fused_pool'))
            cfg_.update(self.fused_convdown_cfg)
        else:
            # NOTE(review): this forwards the already-built downsample
            # callable; confirm that EqualizedLRConvDownModule accepts a
            # callable here rather than a config dict.
            cfg_ = dict(downsample=self.downsample)
            cfg_.update(self.conv_module_cfg)

        modules.append(
            EqualizedLRConvDownModule(
                mid_channels, self._num_out_channels(log2_scale - 2),
                **cfg_))
        return modules

    def forward(self, x, transition_weight=1., curr_scale=-1):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1.0.
            curr_scale (int, optional): The scale for the current inference or
                training. Values smaller than 4 select ``in_scale``.
                Defaults to -1.

        Returns:
            Tensor | tuple[Tensor, Tensor]: Predict score for the input
            image. When ``label_size > 0`` a tuple is returned: the first
            output column (score) and the remaining columns (label part).
        """
        if curr_scale < 4:
            curr_log2_scale = self.in_log2_scale
        else:
            curr_log2_scale = int(np.log2(curr_scale))

        original_img = x
        x = self.fromrgb_layers[curr_log2_scale - 2](x)

        # Walk from the current scale down to log2 scale 3. The single
        # module stored at ``conv_blocks[0]`` (log2 scale 2) is not applied
        # in this loop.
        for s in range(curr_log2_scale, 2, -1):
            x = self.conv_blocks[2 * s - 5](x)
            x = self.conv_blocks[2 * s - 4](x)
            if s == curr_log2_scale:
                # Resolution transition: blend the freshly processed
                # features with the from-RGB features of the downsampled
                # input. ``transition_weight == 1`` keeps only ``x``.
                img_down = self.downsample(original_img)
                y = self.fromrgb_layers[curr_log2_scale - 3](img_down)
                x = y + transition_weight * (x - y)

        if self.with_mbstd:
            x = self.mbstd_layer(x)

        x = self.decision(x)

        if self.label_size > 0:
            return x[:, :1], x[:, 1:]
        return x