# Source code for mmedit.models.editors.biggan.biggan_deep_discriminator
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import mmengine
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import normal_init, xavier_init
from mmengine.runner import load_checkpoint
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from torch.nn.utils import spectral_norm
from mmedit.registry import MODELS
from .biggan_modules import SelfAttentionBlock, SNConvModule
from .biggan_snmodule import SNEmbedding, SNLinear
@MODELS.register_module()
class BigGANDeepDiscriminator(nn.Module):
    """BigGAN-Deep Discriminator. The implementation refers to
    https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py # noqa.

    The overall structure of BigGAN's discriminator is the same as that of
    the projection discriminator.

    The main difference between BigGAN and BigGAN-deep is that BigGAN-deep
    uses deeper residual blocks to construct the whole model.

    More details can be found in: Large Scale GAN Training for High Fidelity
    Natural Image Synthesis (ICLR2019).

    The design of the model structure depends heavily on the input
    resolution. For the original BigGAN-Deep discriminator, you can set
    ``input_scale`` as you need and use the default values of ``arch_cfg``
    and ``blocks_cfg``. If you want to customize the model, you can set the
    arguments in this way:

    ``arch_cfg``: Config for the architecture of this discriminator. You can
    refer to the ``_default_arch_cfgs`` in the ``_get_default_arch_cfg``
    function to see the format of the ``arch_cfg``. Basically, you need to
    provide information about each stage, such as the numbers of input and
    output channels, whether to perform downsampling, and whether to append
    a self-attention block.

    ``blocks_cfg``: Config for the convolution block. You can adjust block
    params like ``channel_ratio`` here. You can also replace the block type
    with your registered customized block. However, note that some params
    are shared between these blocks, like ``act_cfg``,
    ``with_spectral_norm``, ``sn_eps`` etc.

    Args:
        input_scale (int): The scale of the input image.
        num_classes (int, optional): The number of conditional classes.
            Defaults to 0.
        in_channels (int, optional): The channel number of the input image.
            Defaults to 3.
        out_channels (int, optional): The channel number of the final output.
            Defaults to 1.
        base_channels (int, optional): The basic channel number of the
            discriminator. The other layers contain channels based on this
            number. Defaults to 96.
        block_depth (int, optional): The repeat times of Residual Blocks in
            each level of architecture. Defaults to 2.
        sn_eps (float, optional): Epsilon value for spectral normalization.
            Defaults to 1e-6.
        sn_style (str, optional): The style of spectral normalization.
            If set to `ajbrock`, implementation by
            ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
            will be adopted.
            If set to `torch`, implementation by `PyTorch` will be adopted.
            Defaults to `ajbrock`.
        init_type (str, optional): The name of an initialization method:
            ortho | N02 | xavier. Defaults to 'ortho'.
        act_cfg (dict, optional): Config for the activation layer.
            Defaults to dict(type='ReLU', inplace=False).
        with_spectral_norm (bool, optional): Whether to use spectral
            normalization. Defaults to True.
        blocks_cfg (dict, optional): Config for the convolution block.
            Defaults to dict(type='BigGANDeepDiscResBlock').
        arch_cfg (dict, optional): Config for the architecture of this
            discriminator. Defaults to None, in which case the default
            architecture for ``input_scale`` is used.
        pretrained (str | dict, optional): Path for the pretrained model or
            dict containing information for pretrained models whose necessary
            key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
            the discriminator part from the whole state dict.
            Defaults to None.
    """

    def __init__(self,
                 input_scale,
                 num_classes=0,
                 in_channels=3,
                 out_channels=1,
                 base_channels=96,
                 block_depth=2,
                 sn_eps=1e-6,
                 sn_style='ajbrock',
                 init_type='ortho',
                 act_cfg=dict(type='ReLU', inplace=False),
                 with_spectral_norm=True,
                 blocks_cfg=dict(type='BigGANDeepDiscResBlock'),
                 arch_cfg=None,
                 pretrained=None):
        super().__init__()
        self.num_classes = num_classes
        self.out_channels = out_channels
        self.input_scale = input_scale
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.block_depth = block_depth
        self.arch = arch_cfg if arch_cfg else self._get_default_arch_cfg(
            self.input_scale, self.base_channels)

        # Options shared by every residual block. ``blocks_cfg`` is deep-copied
        # so the (mutable) default argument is never modified in place.
        self.blocks_cfg = deepcopy(blocks_cfg)
        self.blocks_cfg.update(
            dict(
                act_cfg=act_cfg,
                sn_eps=sn_eps,
                sn_style=sn_style,
                with_spectral_norm=with_spectral_norm))

        # Stem conv: image channels -> channels of the first stage. Use the
        # configured ``in_channels`` (not a hard-coded 3) so that inputs with
        # a different channel count are supported.
        self.input_conv = SNConvModule(
            self.in_channels,
            self.arch['in_channels'][0],
            kernel_size=3,
            padding=1,
            with_spectral_norm=with_spectral_norm,
            spectral_norm_cfg=dict(eps=sn_eps, sn_style=sn_style),
            act_cfg=None)

        # Each stage has ``block_depth`` residual blocks. Only the first block
        # of a stage changes the channel number and (optionally) downsamples;
        # an optional self-attention block follows the stage.
        self.conv_blocks = nn.ModuleList()
        for index, out_ch in enumerate(self.arch['out_channels']):
            for depth in range(self.block_depth):
                # change args to adapt to current block
                block_cfg_ = deepcopy(self.blocks_cfg)
                block_cfg_.update(
                    dict(
                        in_channels=self.arch['in_channels'][index]
                        if depth == 0 else out_ch,
                        out_channels=out_ch,
                        with_downsample=self.arch['downsample'][index]
                        and depth == 0))
                self.conv_blocks.append(MODELS.build(block_cfg_))

            if self.arch['attention'][index]:
                self.conv_blocks.append(
                    SelfAttentionBlock(
                        out_ch,
                        with_spectral_norm=with_spectral_norm,
                        sn_eps=sn_eps,
                        sn_style=sn_style))

        self.activate = MODELS.build(act_cfg)

        # Linear head producing the unconditional score.
        self.decision = nn.Linear(self.arch['out_channels'][-1], out_channels)
        if with_spectral_norm:
            if sn_style == 'torch':
                self.decision = spectral_norm(self.decision, eps=sn_eps)
            elif sn_style == 'ajbrock':
                self.decision = SNLinear(
                    self.arch['out_channels'][-1], out_channels, eps=sn_eps)
            else:
                raise NotImplementedError(
                    f'{sn_style} style SN is not supported yet')

        # Class embedding for the projection term (conditional models only).
        if self.num_classes > 0:
            self.proj_y = nn.Embedding(self.num_classes,
                                       self.arch['out_channels'][-1])
            if with_spectral_norm:
                if sn_style == 'torch':
                    self.proj_y = spectral_norm(self.proj_y, eps=sn_eps)
                elif sn_style == 'ajbrock':
                    self.proj_y = SNEmbedding(
                        self.num_classes,
                        self.arch['out_channels'][-1],
                        eps=sn_eps)
                else:
                    raise NotImplementedError(
                        f'{sn_style} style SN is not supported yet')

        self.init_weights(pretrained=pretrained, init_type=init_type)

    def _get_default_arch_cfg(self, input_scale, base_channels):
        """Return the default architecture config for ``input_scale``.

        Args:
            input_scale (int): Input resolution. Must be one of
                32, 64, 128, 256 or 512.
            base_channels (int): Basic channel number; every channel entry
                is a multiple of it.

        Returns:
            dict: Config with keys ``in_channels``, ``out_channels``,
            ``downsample``, ``resolution`` and ``attention``. ``__init__``
            iterates over ``out_channels`` and reads ``in_channels``,
            ``downsample`` and ``attention`` at the same index. Entries
            beyond ``len(out_channels)`` are not used, and ``resolution``
            is kept for reference only.
        """
        assert input_scale in [32, 64, 128, 256, 512]
        _default_arch_cfgs = {
            '32': {
                'in_channels': [base_channels * item for item in [4, 4, 4]],
                'out_channels': [base_channels * item for item in [4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 8, 8, 8],
                'attention': [False, False, False, False]
            },
            '64': {
                'in_channels': [base_channels * item for item in [1, 2, 4, 8]],
                'out_channels':
                [base_channels * item for item in [2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': [False, False, False, False, False]
            },
            '128': {
                'in_channels':
                [base_channels * item for item in [1, 2, 4, 8, 16]],
                'out_channels':
                [base_channels * item for item in [2, 4, 8, 16, 16]],
                'downsample': [True] * 5 + [False],
                'resolution': [64, 32, 16, 8, 4, 4],
                'attention': [True, False, False, False, False, False]
            },
            '256': {
                'in_channels':
                [base_channels * item for item in [1, 2, 4, 8, 8, 16]],
                'out_channels':
                [base_channels * item for item in [2, 4, 8, 8, 16, 16]],
                'downsample': [True] * 6 + [False],
                'resolution': [128, 64, 32, 16, 8, 4, 4],
                'attention': [False, True, False, False, False, False]
            },
            '512': {
                'in_channels':
                [base_channels * item for item in [1, 1, 2, 4, 8, 8, 16]],
                'out_channels':
                [base_channels * item for item in [1, 2, 4, 8, 8, 16, 16]],
                'downsample': [True] * 7 + [False],
                'resolution': [256, 128, 64, 32, 16, 8, 4, 4],
                'attention': [False, False, False, True, False, False, False]
            }
        }

        return _default_arch_cfgs[str(input_scale)]

    def forward(self, x, label=None):
        """Forward function.

        Args:
            x (torch.Tensor): Fake or real image tensor. Presumably shaped
                (N, C, H, W); the final features are summed over dims 2
                and 3.
            label (torch.Tensor | None): Label Tensor. Required when
                ``num_classes > 0`` because it is fed to the class
                embedding. Defaults to None.

        Returns:
            torch.Tensor: Prediction for the reality of the input image with
                given label.
        """
        x0 = self.input_conv(x)
        for conv_block in self.conv_blocks:
            x0 = conv_block(x0)
        x0 = self.activate(x0)
        # Global sum pooling over the spatial dimensions.
        x0 = torch.sum(x0, dim=[2, 3])
        out = self.decision(x0)

        if self.num_classes > 0:
            # Projection term: inner product between the class embedding and
            # the pooled features.
            w_y = self.proj_y(label)
            out = out + torch.sum(w_y * x0, dim=1, keepdim=True)
        return out

    def init_weights(self, pretrained=None, init_type='ortho'):
        """Init weights for models.

        Args:
            pretrained (str | dict, optional): Path for the pretrained model or
                dict containing information for pretrained models whose
                necessary key is 'ckpt_path'. Optional keys are 'prefix'
                (default ''), used to load the discriminator part from the
                whole state dict, 'map_location' (default 'cpu') and 'strict'
                (default True). A str path is loaded with ``strict=False``.
                Defaults to None.
            init_type (str, optional): The name of an initialization method:
                ortho | N02 | xavier. Only used when ``pretrained`` is None.
                Defaults to 'ortho'.

        Raises:
            NotImplementedError: If ``init_type`` is unknown.
            TypeError: If ``pretrained`` is not a str, dict or None.
        """
        if isinstance(pretrained, str):
            logger = MMLogger.get_current_instance()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif isinstance(pretrained, dict):
            ckpt_path = pretrained.get('ckpt_path', None)
            assert ckpt_path is not None
            prefix = pretrained.get('prefix', '')
            map_location = pretrained.get('map_location', 'cpu')
            strict = pretrained.get('strict', True)
            state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                      map_location)
            self.load_state_dict(state_dict, strict=strict)
            mmengine.print_log(f'Load pretrained model from {ckpt_path}')
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
                    if init_type == 'ortho':
                        nn.init.orthogonal_(m.weight)
                    elif init_type == 'N02':
                        normal_init(m, 0.0, 0.02)
                    elif init_type == 'xavier':
                        xavier_init(m)
                    else:
                        raise NotImplementedError(
                            f'{init_type} initialization not supported now.')
        else:
            raise TypeError('pretrained must be a str, dict or None but'
                            f' got {type(pretrained)} instead.')