mmedit.models.editors.biggan.biggan_deep_generator 源代码
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import mmengine
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import normal_init, xavier_init
from mmengine.runner import load_checkpoint
from mmengine.runner.checkpoint import _load_checkpoint_with_prefix
from torch.nn.utils import spectral_norm
from mmedit.registry import MODELS
from ...utils import get_module_device
from .biggan_modules import SelfAttentionBlock, SNConvModule
from .biggan_snmodule import SNLinear
@MODELS.register_module()
[文档]class BigGANDeepGenerator(nn.Module):
"""BigGAN-Deep Generator. The implementation refers to
https://github.com/ajbrock/BigGAN-PyTorch/blob/master/BigGANdeep.py # noqa.
In BigGAN, we use a SAGAN-based architecture composing of an
self-attention block and number of convolutional residual blocks
with spectral normalization. BigGAN-deep follow the same architecture.
The main difference between BigGAN and BigGAN-deep is that
BigGAN-deep uses deeper residual blocks to construct the whole
model.
More details can be found in: Large Scale GAN Training for High Fidelity
Natural Image Synthesis (ICLR2019).
The design of the model structure is highly corresponding to the output
resolution. For the original BigGAN-Deep's generator, you can set ``output_scale``
as you need and use the default value of ``arch_cfg`` and ``blocks_cfg``.
If you want to customize the model, you can set the arguments in this way:
``arch_cfg``: Config for the architecture of this generator. You can refer
the ``_default_arch_cfgs`` in the ``_get_default_arch_cfg`` function to see
the format of the ``arch_cfg``. Basically, you need to provide information
of each block such as the numbers of input and output channels, whether to
perform upsampling, etc.
``blocks_cfg``: Config for the convolution block. You can adjust block params
like ``channel_ratio`` here. You can also replace the block type
to your registered customized block. However, you should notice that some
params are shared among these blocks like ``act_cfg``, ``with_spectral_norm``,
``sn_eps``, etc.
Args:
output_scale (int): Output scale for the generated image.
noise_size (int, optional): Size of the input noise vector. Defaults
to 120.
num_classes (int, optional): The number of conditional classes. If set
to 0, this model will be degraded to an unconditional model.
Defaults to 0.
out_channels (int, optional): Number of channels in output images.
Defaults to 3.
base_channels (int, optional): The basic channel number of the
generator. The other layers contains channels based on this number.
Defaults to 96.
block_depth (int, optional): The repeat times of Residual Blocks in
each level of architecture. Defaults to 2.
input_scale (int, optional): The scale of the input 2D feature map.
Defaults to 4.
with_shared_embedding (bool, optional): Whether to use shared
embedding. Defaults to True.
shared_dim (int, optional): The output channels of shared embedding.
Defaults to 128.
sn_eps (float, optional): Epsilon value for spectral normalization.
Defaults to 1e-6.
sn_style (str, optional): The style of spectral normalization.
If set to `ajbrock`, implementation by
ajbrock(https://github.com/ajbrock/BigGAN-PyTorch/blob/master/layers.py)
will be adopted.
If set to `torch`, implementation by `PyTorch` will be adopted.
Defaults to `ajbrock`.
init_type (str, optional): The name of an initialization method:
ortho | N02 | xavier. Defaults to 'ortho'.
concat_noise (bool, optional): Whether to concat input noise vector
with class vector. Defaults to True.
act_cfg (dict, optional): Config for the activation layer. Defaults to
dict(type='ReLU').
upsample_cfg (dict, optional): Config for the upsampling operation.
Defaults to dict(type='nearest', scale_factor=2).
with_spectral_norm (bool, optional): Whether to use spectral
normalization. Defaults to True.
auto_sync_bn (bool, optional): Whether to use synchronized batch
normalization. Defaults to True.
blocks_cfg (dict, optional): Config for the convolution block. Defaults
to dict(type='BigGANGenResBlock').
arch_cfg (dict, optional): Config for the architecture of this
generator. Defaults to None.
out_norm_cfg (dict, optional): Config for the norm of output layer.
Defaults to dict(type='BN').
pretrained (str | dict, optional): Path for the pretrained model or
dict containing information for pretained models whose necessary
key is 'ckpt_path'. Besides, you can also provide 'prefix' to load
the generator part from the whole state dict. Defaults to None.
rgb2bgr (bool, optional): Whether to reformat the output channels
with order `bgr`. We provide several pre-trained BigGAN-Deep
weights whose output channels order is `rgb`. You can set
this argument to True to use the weights.
"""
def __init__(self,
output_scale,
noise_size=120,
num_classes=0,
out_channels=3,
base_channels=96,
block_depth=2,
input_scale=4,
with_shared_embedding=True,
shared_dim=128,
sn_eps=1e-6,
sn_style='ajbrock',
init_type='ortho',
concat_noise=True,
act_cfg=dict(type='ReLU', inplace=False),
upsample_cfg=dict(type='nearest', scale_factor=2),
with_spectral_norm=True,
auto_sync_bn=True,
blocks_cfg=dict(type='BigGANDeepGenResBlock'),
arch_cfg=None,
out_norm_cfg=dict(type='BN'),
pretrained=None,
rgb2bgr=False):
super().__init__()
self.noise_size = noise_size
self.num_classes = num_classes
self.shared_dim = shared_dim
self.with_shared_embedding = with_shared_embedding
self.output_scale = output_scale
self.arch = arch_cfg if arch_cfg else self._get_default_arch_cfg(
self.output_scale, base_channels)
self.input_scale = input_scale
self.concat_noise = concat_noise
self.blocks_cfg = deepcopy(blocks_cfg)
self.upsample_cfg = deepcopy(upsample_cfg)
self.block_depth = block_depth
self.rgb2bgr = rgb2bgr
self.sn_style = sn_style
# Validity Check
# If 'num_classes' equals to zero, we shall set 'with_shared_embedding'
# to False.
if num_classes == 0:
assert not self.with_shared_embedding
assert not self.concat_noise
elif not self.with_shared_embedding:
# If not `with_shared_embedding`, we will use `nn.Embedding` to
# replace the original `Linear` layer in conditional BN.
# Meanwhile, we do not adopt split noises.
assert not self.concat_noise
# First linear layer
if self.concat_noise:
self.noise2feat = nn.Linear(
self.noise_size + self.shared_dim,
self.arch['in_channels'][0] * (self.input_scale**2))
else:
self.noise2feat = nn.Linear(
self.noise_size,
self.arch['in_channels'][0] * (self.input_scale**2))
if with_spectral_norm:
if sn_style == 'torch':
self.noise2feat = spectral_norm(self.noise2feat, eps=sn_eps)
elif sn_style == 'ajbrock':
self.noise2feat = SNLinear(
self.noise_size +
(self.shared_dim if self.concat_noise else 0),
self.arch['in_channels'][0] * (self.input_scale**2),
eps=sn_eps)
else:
NotImplementedError(f'{sn_style} style SN is not supported')
# If using 'shared_embedding', we will get an unified embedding of
# label for all blocks. If not, we just pass the label to each
# block.
if with_shared_embedding:
self.shared_embedding = nn.Embedding(num_classes, shared_dim)
else:
self.shared_embedding = nn.Identity()
if num_classes > 0:
if self.concat_noise:
self.dim_after_concat = (
self.shared_dim + self.noise_size
if self.with_shared_embedding else self.num_classes)
else:
self.dim_after_concat = (
self.shared_dim
if self.with_shared_embedding else self.num_classes)
else:
self.dim_after_concat = 0
self.blocks_cfg.update(
dict(
dim_after_concat=self.dim_after_concat,
act_cfg=act_cfg,
sn_eps=sn_eps,
sn_style=sn_style,
input_is_label=(num_classes > 0)
and (not with_shared_embedding),
with_spectral_norm=with_spectral_norm,
auto_sync_bn=auto_sync_bn))
self.conv_blocks = nn.ModuleList()
for index, out_ch in enumerate(self.arch['out_channels']):
for depth in range(self.block_depth):
# change args to adapt to current block
block_cfg_ = deepcopy(self.blocks_cfg)
block_cfg_.update(
dict(
in_channels=self.arch['in_channels'][index],
out_channels=out_ch if depth == (self.block_depth - 1)
else self.arch['in_channels'][index],
upsample_cfg=self.upsample_cfg
if self.arch['upsample'][index]
and depth == (self.block_depth - 1) else None))
self.conv_blocks.append(MODELS.build(block_cfg_))
if self.arch['attention'][index]:
self.conv_blocks.append(
SelfAttentionBlock(
out_ch,
with_spectral_norm=with_spectral_norm,
sn_eps=sn_eps,
sn_style=sn_style))
self.output_layer = SNConvModule(
self.arch['out_channels'][-1],
out_channels,
kernel_size=3,
padding=1,
with_spectral_norm=with_spectral_norm,
spectral_norm_cfg=dict(eps=sn_eps, sn_style=sn_style),
act_cfg=act_cfg,
norm_cfg=out_norm_cfg,
bias=True,
order=('norm', 'act', 'conv'))
self.init_weights(pretrained=pretrained, init_type=init_type)
[文档] def _get_default_arch_cfg(self, output_scale, base_channels):
assert output_scale in [32, 64, 128, 256, 512]
_default_arch_cfgs = {
'32': {
'in_channels': [base_channels * item for item in [4, 4, 4]],
'out_channels': [base_channels * item for item in [4, 4, 4]],
'upsample': [True] * 3,
'resolution': [8, 16, 32],
'attention': [False, False, False]
},
'64': {
'in_channels':
[base_channels * item for item in [16, 16, 8, 4]],
'out_channels':
[base_channels * item for item in [16, 8, 4, 2]],
'upsample': [True] * 4,
'resolution': [8, 16, 32, 64],
'attention': [False, False, False, True]
},
'128': {
'in_channels':
[base_channels * item for item in [16, 16, 8, 4, 2]],
'out_channels':
[base_channels * item for item in [16, 8, 4, 2, 1]],
'upsample': [True] * 5,
'resolution': [8, 16, 32, 64, 128],
'attention': [False, False, False, True, False]
},
'256': {
'in_channels':
[base_channels * item for item in [16, 16, 8, 8, 4, 2]],
'out_channels':
[base_channels * item for item in [16, 8, 8, 4, 2, 1]],
'upsample': [True] * 6,
'resolution': [8, 16, 32, 64, 128, 256],
'attention': [False, False, False, True, False, False]
},
'512': {
'in_channels':
[base_channels * item for item in [16, 16, 8, 8, 4, 2, 1]],
'out_channels':
[base_channels * item for item in [16, 8, 8, 4, 2, 1, 1]],
'upsample': [True] * 7,
'resolution': [8, 16, 32, 64, 128, 256, 512],
'attention': [False, False, False, True, False, False, False]
}
}
return _default_arch_cfgs[str(output_scale)]
[文档] def forward(self,
noise,
label=None,
num_batches=0,
return_noise=False,
truncation=-1.0,
use_outside_embedding=False):
"""Forward function.
Args:
noise (torch.Tensor | callable | None): You can directly give a
batch of noise through a ``torch.Tensor`` or offer a callable
function to sample a batch of noise data. Otherwise, the
``None`` indicates to use the default noise sampler.
label (torch.Tensor | callable | None): You can directly give a
batch of label through a ``torch.Tensor`` or offer a callable
function to sample a batch of label data. Otherwise, the
``None`` indicates to use the default label sampler.
Defaults to None.
num_batches (int, optional): The number of batch size.
Defaults to 0.
return_noise (bool, optional): If True, ``noise_batch`` and
``label`` will be returned in a dict with ``fake_img``.
Defaults to False.
truncation (float, optional): Truncation factor. Give value not
less than 0., the truncation trick will be adopted.
Otherwise, the truncation trick will not be adopted.
Defaults to -1..
use_outside_embedding (bool, optional): Whether to use outside
embedding or use `shared_embedding`. Set to `True` if
embedding has already be performed outside this function.
Default to False.
Returns:
torch.Tensor | dict: If not ``return_noise``, only the output image
will be returned. Otherwise, a dict contains ``fake_img``,
``noise_batch`` and ``label`` will be returned.
"""
if isinstance(noise, torch.Tensor):
assert noise.shape[1] == self.noise_size
assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
f'but got {noise.shape}')
noise_batch = noise
# receive a noise generator and sample noise.
elif callable(noise):
noise_generator = noise
assert num_batches > 0
noise_batch = noise_generator((num_batches, self.noise_size))
# otherwise, we will adopt default noise sampler.
else:
assert num_batches > 0
noise_batch = torch.randn((num_batches, self.noise_size))
# perform truncation
if truncation >= 0.0:
noise_batch = torch.clamp(noise_batch, -1. * truncation,
1. * truncation)
if self.num_classes == 0:
label_batch = None
elif isinstance(label, torch.Tensor):
if not use_outside_embedding:
assert label.ndim == 1, (
'The label shoube be in shape of (n, )'
f'but got {label.shape}.')
label_batch = label
elif callable(label):
label_generator = label
assert num_batches > 0
label_batch = label_generator((num_batches, ))
else:
assert num_batches > 0
label_batch = torch.randint(0, self.num_classes, (num_batches, ))
# dirty code for putting data on the right device
noise_batch = noise_batch.to(get_module_device(self))
if label_batch is not None:
label_batch = label_batch.to(get_module_device(self))
if not use_outside_embedding:
class_vector = self.shared_embedding(label_batch)
else:
class_vector = label_batch
else:
class_vector = None
# If 'concat noise', concat class vector and noise batch
if self.concat_noise:
if class_vector is not None:
z = torch.cat([noise_batch, class_vector], dim=1)
y = z
elif self.num_classes > 0:
z = noise_batch
y = class_vector
else:
z = noise_batch
y = None
# First linear layer
x = self.noise2feat(z)
# Reshape
# We use this conversion step to allow for loading TF weights
# TF convention on shape is [batch, height, width, channels]
# PT convention on shape is [batch, channels, height, width]
x = x.view(x.size(0), self.input_scale, self.input_scale, -1)
x = x.permute(0, 3, 1, 2).contiguous()
# Loop over blocks
for idx, conv_block in enumerate(self.conv_blocks):
if isinstance(conv_block, SelfAttentionBlock):
x = conv_block(x)
else:
x = conv_block(x, y)
# Apply batchnorm-relu-conv-tanh at output
x = self.output_layer(x)
out_img = torch.tanh(x)
if self.rgb2bgr:
out_img = out_img[:, [2, 1, 0], ...]
if return_noise:
output = dict(
fake_img=out_img, noise_batch=noise_batch, label=label_batch)
return output
return out_img
[文档] def init_weights(self, pretrained=None, init_type='ortho'):
"""Init weights for models.
Args:
pretrained (str | dict, optional): Path for the pretrained model or
dict containing information for pretained models whose
necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the generator part from the whole state dict.
Defaults to None.
init_type (str, optional): The name of an initialization method:
ortho | N02 | xavier. Defaults to 'ortho'.
"""
if isinstance(pretrained, str):
logger = MMLogger.get_current_instance()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif isinstance(pretrained, dict):
ckpt_path = pretrained.get('ckpt_path', None)
assert ckpt_path is not None
prefix = pretrained.get('prefix', '')
map_location = pretrained.get('map_location', 'cpu')
strict = pretrained.get('strict', True)
state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
map_location)
self.load_state_dict(state_dict, strict=strict)
mmengine.print_log(f'Load pretrained model from {ckpt_path}')
elif pretrained is None:
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear, nn.Embedding)):
if init_type == 'ortho':
nn.init.orthogonal_(m.weight)
elif init_type == 'N02':
normal_init(m, 0.0, 0.02)
elif init_type == 'xavier':
xavier_init(m)
else:
raise NotImplementedError(
f'{init_type} initialization \
not supported now.')
else:
raise TypeError('pretrained must be a str or None but'
f' got {type(pretrained)} instead.')