# Source code for mmedit.models.editors.pggan.pggan_generator
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import numpy as np
import torch
import torch.nn as nn
from mmedit.registry import MODELS
from ...utils import get_module_device
from .pggan_modules import (EqualizedLRConvModule, EqualizedLRConvUpModule,
PGGANNoiseTo2DFeat)
@MODELS.register_module()
class PGGANGenerator(nn.Module):
    """Generator for PGGAN.

    The generator maps a noise vector (optionally concatenated with a label
    vector) to a 4x4 feature map and then applies a stack of convolution
    blocks, each non-initial block upsampling the feature map. A dedicated
    to-rgb layer exists for every scale so that an image can be produced at
    any intermediate scale, blended with the upsampled image of the previous
    scale through ``transition_weight``.

    Args:
        noise_size (int): Size of the input noise vector. If a falsy value
            (e.g. ``None`` or ``0``) is given,
            ``min(base_channels, max_channels)`` is used instead.
        out_scale (int): Output scale for the generated image. Must be a
            power of two and at least 4.
        label_size (int, optional): Size of the label vector.
            Defaults to 0.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contains channels based on this
            number. Defaults to 8192.
        channel_decay (float, optional): Decay for channels of feature maps.
            Defaults to 1.0.
        max_channels (int, optional): Maximum channels for the feature
            maps in the generator block. Defaults to 512.
        fused_upconv (bool, optional): Whether use fused upconv.
            Defaults to True.
        conv_module_cfg (dict, optional): Config for the convolution
            module used in this generator. Entries override the class-level
            defaults. Defaults to None.
        fused_upconv_cfg (dict, optional): Config for the fused upconv
            module used in this generator. Entries override the class-level
            defaults. Defaults to None.
        upsample_cfg (dict, optional): Config for the upsampling operation.
            Entries override the class-level defaults. Defaults to None.
    """

    # Default config of the fused "upsample + conv" module.
    _default_fused_upconv_cfg = dict(
        conv_cfg=dict(type='deconv'),
        kernel_size=3,
        stride=2,
        padding=1,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=dict(type='PixelNorm'),
        order=('conv', 'act', 'norm'))

    # Default config of the plain convolution module.
    _default_conv_module_cfg = dict(
        conv_cfg=None,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
        norm_cfg=dict(type='PixelNorm'),
        order=('conv', 'act', 'norm'))

    # Default config of the upsampling layer. ``__init__`` reads this
    # attribute, so it has to exist on the class.
    # NOTE(review): nearest-neighbour 2x upsampling is assumed to be the
    # intended default -- confirm against the upstream implementation.
    _default_upsample_cfg = dict(type='nearest', scale_factor=2)

    def __init__(self,
                 noise_size,
                 out_scale,
                 label_size=0,
                 base_channels=8192,
                 channel_decay=1.,
                 max_channels=512,
                 fused_upconv=True,
                 conv_module_cfg=None,
                 fused_upconv_cfg=None,
                 upsample_cfg=None):
        super().__init__()
        self.noise_size = noise_size if noise_size else min(
            base_channels, max_channels)
        self.out_scale = out_scale
        self.out_log2_scale = int(np.log2(out_scale))
        # sanity check for the output scale
        assert out_scale == 2**self.out_log2_scale and out_scale >= 4
        self.label_size = label_size
        self.base_channels = base_channels
        self.channel_decay = channel_decay
        self.max_channels = max_channels
        self.fused_upconv = fused_upconv

        # set conv cfg; deepcopy keeps the class-level defaults untouched
        self.conv_module_cfg = deepcopy(self._default_conv_module_cfg)
        # update with customized config
        if conv_module_cfg:
            self.conv_module_cfg.update(conv_module_cfg)

        if self.fused_upconv:
            self.fused_upconv_cfg = deepcopy(self._default_fused_upconv_cfg)
            # update with customized config
            if fused_upconv_cfg:
                self.fused_upconv_cfg.update(fused_upconv_cfg)

        self.upsample_cfg = deepcopy(self._default_upsample_cfg)
        if upsample_cfg is not None:
            self.upsample_cfg.update(upsample_cfg)

        # Use the resolved ``self.noise_size`` (not the raw argument) so the
        # first layer agrees with the size checked/sampled in ``forward``
        # when the fallback value is in effect.
        self.noise2feat = PGGANNoiseTo2DFeat(self.noise_size + label_size,
                                             self._num_out_channels(1))

        self.torgb_layers = nn.ModuleList()
        self.conv_blocks = nn.ModuleList()

        for s in range(2, self.out_log2_scale + 1):
            # the 4x4 block consumes the output of ``noise2feat``; every
            # later block consumes the output of the previous block
            in_ch = self._num_out_channels(
                s - 1) if s == 2 else self._num_out_channels(s - 2)
            # setup torgb layers
            self.torgb_layers.append(
                self._get_torgb_layer(self._num_out_channels(s - 1)))
            # setup upconv or conv blocks
            self.conv_blocks.extend(self._get_upconv_block(in_ch, s))

        # build upsample layer for residual path
        self.upsample_layer = MODELS.build(self.upsample_cfg)

    def _get_torgb_layer(self, in_channels: int):
        """Get the to-rgb layer based on `in_channels`.

        Args:
            in_channels (int): Number of input channels.

        Returns:
            nn.Module: To-rgb layer, a 1x1 equalized-lr convolution with
                3 output channels and neither norm nor activation.
        """
        return EqualizedLRConvModule(
            in_channels,
            3,
            kernel_size=1,
            stride=1,
            equalized_lr_cfg=dict(gain=1),
            bias=True,
            norm_cfg=None,
            act_cfg=None)

    def _num_out_channels(self, log_scale: int):
        """Calculate the number of output channels based on logarithm of
        current scale.

        The value is ``base_channels / 2**(log_scale * channel_decay)``
        truncated to an integer and capped at ``max_channels``.

        Args:
            log_scale (int): The logarithm of the current scale.

        Returns:
            int: The current number of output channels.
        """
        return min(
            int(self.base_channels / (2.0**(log_scale * self.channel_decay))),
            self.max_channels)

    def _get_upconv_block(self, in_channels, log_scale):
        """Get the conv block for upsampling.

        Args:
            in_channels (int): The number of input channels.
            log_scale (int): The logarithmic of the current scale.

        Returns:
            list[nn.Module]: The modules of the block. A single conv module
                for the 4x4 scale (``log_scale == 2``); otherwise an
                "upsample + conv" module followed by a refine conv module.
        """
        modules = []
        # start 4x4 scale
        if log_scale == 2:
            modules.append(
                EqualizedLRConvModule(in_channels,
                                      self._num_out_channels(log_scale - 1),
                                      **self.conv_module_cfg))
        # 8x8 --> 1024x1024 scales
        else:
            if self.fused_upconv:
                cfg_ = dict(upsample=dict(type='fused_nn'))
                cfg_.update(self.fused_upconv_cfg)
            else:
                # copy so that no block shares (and can mutate) the config
                # dict stored on the generator
                cfg_ = dict(upsample=deepcopy(self.upsample_cfg))
                cfg_.update(self.conv_module_cfg)
            # up + conv
            modules.append(
                EqualizedLRConvUpModule(in_channels,
                                        self._num_out_channels(log_scale - 1),
                                        **cfg_))
            # refine conv
            modules.append(
                EqualizedLRConvModule(
                    self._num_out_channels(log_scale - 1),
                    self._num_out_channels(log_scale - 1),
                    **self.conv_module_cfg))
        return modules

    def forward(self,
                noise,
                label=None,
                num_batches=0,
                return_noise=False,
                transition_weight=1.,
                curr_scale=-1):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` or offer a callable
                function to sample a batch of noise data. Otherwise, the
                ``None`` indicates to use the default noise sampler.
            label (Tensor, optional): Label vector with shape [N, C]. If
                given, it is concatenated to the noise along dim 1.
                Defaults to None.
            num_batches (int, optional): The number of batch size. Must be
                positive when ``noise`` is not a tensor. Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            transition_weight (float, optional): The weight used in resolution
                transition. Defaults to 1.0.
            curr_scale (int, optional): The scale for the current inference or
                training. A negative value selects ``out_scale``.
                Defaults to -1.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output image
                will be returned. Otherwise, a dict containing ``fake_img``,
                ``noise_batch`` and ``label`` will be returned. Note that
                ``noise_batch`` is the tensor fed to the network, i.e. it
                already includes the concatenated label when one is given.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            # check the rank first so that a wrongly shaped tensor triggers
            # the descriptive message instead of an IndexError
            assert noise.ndim == 2, ('The noise should be in shape of (n, c), '
                                     f'but got {noise.shape}')
            assert noise.shape[1] == self.noise_size
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            # TODO: check pggan default noise type
            noise_batch = torch.randn((num_batches, self.noise_size))

        # dirty code for putting data on the right device
        noise_batch = noise_batch.to(get_module_device(self))

        if label is not None:
            # ``label.to(noise_batch)`` matches the dtype/device of the noise
            noise_batch = torch.cat([noise_batch,
                                     label.to(noise_batch)],
                                    dim=1)

        # noise vector to 2D feature
        x = self.noise2feat(noise_batch)

        # build current computational graph
        curr_log2_scale = self.out_log2_scale if curr_scale < 0 else int(
            np.log2(curr_scale))

        # 4x4 scale
        x = self.conv_blocks[0](x)
        if curr_log2_scale <= 3:
            out_img = last_img = self.torgb_layers[0](x)

        # 8x8 and larger scales
        # The 4x4 block owns index 0 of ``conv_blocks``; each later scale
        # ``s`` owns the two consecutive indices ``2s - 5`` and ``2s - 4``.
        for s in range(3, curr_log2_scale + 1):
            x = self.conv_blocks[2 * s - 5](x)
            x = self.conv_blocks[2 * s - 4](x)
            if s + 1 == curr_log2_scale:
                # image of the previous scale, used for the residual path
                last_img = self.torgb_layers[s - 2](x)
            elif s == curr_log2_scale:
                out_img = self.torgb_layers[s - 2](x)
                residual_img = self.upsample_layer(last_img)
                # linear blend between the upsampled previous-scale image
                # and the current-scale image
                out_img = residual_img + transition_weight * (
                    out_img - residual_img)

        if return_noise:
            output = dict(
                fake_img=out_img, noise_batch=noise_batch, label=label)
            return output

        return out_img