注意
您正在阅读 MMEditing 0.x。 MMEditing 0.x 会在 2022 年末开始逐步停止维护,建议您及时升级到 MMEditing 1.0 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、 代码 和 文档 以了解更多。
mmedit.models.components.stylegan2.generator_discriminator 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import random
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import _load_checkpoint_with_prefix
from mmedit.models.registry import COMPONENTS
from .common import get_mean_latent, get_module_device, style_mixing
from .modules import (ConstantInput, ConvDownLayer, EqualLinearActModule,
ModMBStddevLayer, ModulatedStyleConv, ModulatedToRGB,
PixelNorm, ResBlock)
@COMPONENTS.register_module()
class StyleGANv2Generator(nn.Module):
    r"""StyleGAN2 Generator.

    This module comes from MMGeneration. In the future, this code will be
    removed and StyleGANv2 will be directly imported from mmgeneration.

    In StyleGAN2, we use a static architecture composed of a style mapping
    module and a number of convolutional style blocks. More details can be
    found in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

    You can load a pretrained model by passing information into the
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the ema model, you can just use the following code:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path=ckpt_http,
                                            prefix='generator_ema'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with a local path. If you just want to load the original
    generator (not the ema model), please set the prefix to 'generator'.

    Note that our implementation allows generating BGR images, while the
    original StyleGAN2 outputs RGB images by default. Thus, we provide the
    ``bgr2rgb`` argument to convert the image space.

    Args:
        out_size (int): The output size of the StyleGAN2 generator. It is
            used as a key of the internal channel table, so it must be one
            of the resolutions listed there (4, 8, ..., 1024).
        style_channels (int): The number of channels for style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            In training, we adopt the mixing style mode by default. However,
            in evaluation, we use the 'single' style mode. `['mix',
            'single']` are currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in range of [0, 1]. Defaults to 0.9.
        pretrained (dict | None, optional): Information for pretrained
            models. The necessary key is 'ckpt_path'. Besides, you can also
            provide 'prefix' to load the generator part from the whole state
            dict. Defaults to None.
        bgr2rgb (bool, optional): Whether to flip the image channel dimension.
            Defaults to False.
    """

    def __init__(self,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 pretrained=None,
                 bgr2rgb=False):
        super(StyleGANv2Generator, self).__init__()
        self.out_size = out_size
        self.style_channels = style_channels
        self.num_mlps = num_mlps
        self.channel_multiplier = channel_multiplier
        self.lr_mlp = lr_mlp
        # ``_default_style_mode`` remembers the training-time mode so that
        # ``train()`` can restore it after ``eval()`` has overwritten
        # ``default_style_mode``.
        self._default_style_mode = default_style_mode
        self.default_style_mode = default_style_mode
        self.eval_style_mode = eval_style_mode
        self.mix_prob = mix_prob
        self.bgr2rgb = bgr2rgb

        # define style mapping layers
        mapping_layers = [PixelNorm()]
        for _ in range(num_mlps):
            mapping_layers.append(
                EqualLinearActModule(
                    style_channels,
                    style_channels,
                    equalized_lr_cfg=dict(lr_mul=lr_mlp, gain=1.),
                    act_cfg=dict(type='fused_bias')))
        self.style_mapping = nn.Sequential(*mapping_layers)

        # feature-map resolution -> number of channels
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # constant input layer
        self.constant_input = ConstantInput(self.channels[4])
        # 4x4 stage
        self.conv1 = ModulatedStyleConv(
            self.channels[4],
            self.channels[4],
            kernel_size=3,
            style_channels=style_channels,
            blur_kernel=blur_kernel)
        self.to_rgb1 = ModulatedToRGB(
            self.channels[4], style_channels, upsample=False)

        # generator backbone (8x8 --> higher resolutions)
        self.log_size = int(np.log2(self.out_size))

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()

        in_channels_ = self.channels[4]

        # Each resolution contributes an upsampling conv, a plain conv and a
        # to-RGB layer; ``self.convs`` therefore stores the two convs of
        # every stage interleaved (even index: upsample, odd index: plain).
        for i in range(3, self.log_size + 1):
            out_channels_ = self.channels[2**i]

            self.convs.append(
                ModulatedStyleConv(
                    in_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=True,
                    blur_kernel=blur_kernel))
            self.convs.append(
                ModulatedStyleConv(
                    out_channels_,
                    out_channels_,
                    3,
                    style_channels,
                    upsample=False,
                    blur_kernel=blur_kernel))
            self.to_rgbs.append(
                ModulatedToRGB(out_channels_, style_channels, upsample=True))

            in_channels_ = out_channels_

        self.num_latents = self.log_size * 2 - 2
        self.num_injected_noises = self.num_latents - 1

        # register buffer for injected noises
        for layer_idx in range(self.num_injected_noises):
            # layer 0 -> 4x4, layers 1-2 -> 8x8, layers 3-4 -> 16x16, ...
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.register_buffer(f'injected_noise_{layer_idx}',
                                 torch.randn(*shape))

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load weights for this generator from a checkpoint.

        Args:
            ckpt_path (str): Path or URL of the checkpoint.
            prefix (str, optional): Prefix forwarded to
                ``_load_checkpoint_with_prefix`` to pick this module's part
                out of the whole state dict. Defaults to ''.
            map_location (str, optional): Location the checkpoint is mapped
                to while loading. Defaults to 'cpu'.
            strict (bool, optional): Forwarded to ``load_state_dict``.
                Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmedit')

    def train(self, mode=True):
        """Switch between training and evaluation mode.

        Besides the standard ``nn.Module.train`` behaviour, this also swaps
        ``default_style_mode`` between the training-time style mode and
        ``eval_style_mode``.

        Args:
            mode (bool, optional): ``True`` for training mode, ``False`` for
                evaluation mode. Defaults to True.

        Returns:
            nn.Module: Whatever ``nn.Module.train`` returns (the module).
        """
        if mode:
            phase, target_mode = 'train', self._default_style_mode
        else:
            phase, target_mode = 'evaluation', self.eval_style_mode

        if self.default_style_mode != target_mode:
            # Log through the 'mmedit' logger, the same one used when
            # loading pretrained weights, instead of the foreign 'mmgen' one.
            mmcv.print_log(f'Switch to {phase} style mode: {target_mode}',
                           'mmedit')
        self.default_style_mode = target_mode

        return super(StyleGANv2Generator, self).train(mode)

    def make_injected_noise(self):
        """Sample a fresh list of per-layer injected noise tensors.

        Returns:
            list[torch.Tensor]: One ``(1, 1, 4, 4)`` tensor followed by two
                ``(1, 1, 2**i, 2**i)`` tensors for every ``i`` from 3 up to
                ``self.log_size``, all on this module's device.
        """
        device = get_module_device(self)

        noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_mean_latent(self, num_samples=4096, **kwargs):
        """Delegate to the module-level ``get_mean_latent`` helper.

        Args:
            num_samples (int, optional): Forwarded to the helper; presumably
                the number of random samples averaged. Defaults to 4096.
            **kwargs: Extra keyword arguments forwarded to the helper.

        Returns:
            The helper's return value (presumably the mean latent tensor,
            usable as ``truncation_latent`` in ``forward``).
        """
        return get_mean_latent(self, num_samples, **kwargs)

    def style_mixing(self,
                     n_source,
                     n_target,
                     inject_index=1,
                     truncation_latent=None,
                     truncation=0.7):
        """Delegate to the module-level ``style_mixing`` helper.

        All arguments are forwarded unchanged together with this generator
        and its ``style_channels``.

        Args:
            n_source (int): Forwarded as ``n_source``.
            n_target (int): Forwarded as ``n_target``.
            inject_index (int, optional): Forwarded as ``inject_index``.
                Defaults to 1.
            truncation_latent (torch.Tensor | None, optional): Forwarded as
                ``truncation_latent``. Defaults to None.
            truncation (float, optional): Forwarded as ``truncation``.
                Defaults to 0.7.

        Returns:
            The helper's return value.
        """
        return style_mixing(
            self,
            n_source=n_source,
            n_target=n_target,
            inject_index=inject_index,
            truncation=truncation,
            truncation_latent=truncation_latent,
            style_channels=self.style_channels)

    def _sample_noise(self, sampler, num_batches):
        """Draw one or two batches of input noise with ``sampler``.

        Two batches are drawn (enabling style mixing) when the current style
        mode is 'mix' and a random draw falls below ``mix_prob``; otherwise
        a single batch is drawn.

        Args:
            sampler (callable): Called with the size tuple
                ``(num_batches, style_channels)``; must return a tensor.
            num_batches (int): Batch size of every sampled noise tensor.

        Returns:
            list[torch.Tensor]: Sampled noise moved to this module's device.
        """
        device = get_module_device(self)
        # ``random.random()`` is only consumed in 'mix' mode (short-circuit).
        use_mixing = (
            self.default_style_mode == 'mix'
            and random.random() < self.mix_prob)
        num_styles = 2 if use_mixing else 1
        return [
            sampler((num_batches, self.style_channels)).to(device)
            for _ in range(num_styles)
        ]

    def forward(self,
                styles,
                num_batches=-1,
                return_noise=False,
                return_latents=False,
                inject_index=None,
                truncation=1,
                truncation_latent=None,
                input_is_latent=False,
                injected_noise=None,
                randomize_noise=True):
        """Forward function.

        This function has been integrated with the truncation trick. Please
        refer to the usage of `truncation` and `truncation_latent`.

        Args:
            styles (torch.Tensor | list[torch.Tensor] | callable | None): In
                StyleGAN2, you can provide noise tensor or latent tensor. Given
                a list containing more than one noise or latent tensors, style
                mixing trick will be used in training. Of course, You can
                directly give a batch of noise through a ``torch.Tensor`` or
                offer a callable function to sample a batch of noise data.
                Otherwise, the ``None`` indicates to use the default noise
                sampler.
            num_batches (int, optional): The batch size used when noise is
                sampled inside this function (``styles`` is callable or
                ``None``); it must be positive in that case.
                Defaults to -1.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            return_latents (bool, optional): If True, ``latent`` will be
                returned in a dict with ``fake_img``. Defaults to False.
            inject_index (int | None, optional): The index number for mixing
                style codes. Defaults to None.
            truncation (float, optional): Truncation factor. Give value less
                than 1., the truncation trick will be adopted. Defaults to 1.
            truncation_latent (torch.Tensor, optional): Mean truncation latent.
                Required when ``truncation < 1``. Defaults to None.
            input_is_latent (bool, optional): If `True`, the input tensor is
                the latent tensor. Defaults to False.
            injected_noise (torch.Tensor | None, optional): Given a tensor, the
                random noise will be fixed as this input injected noise.
                Defaults to None.
            randomize_noise (bool, optional): If `False`, images are sampled
                with the buffered noise tensor injected to the style conv
                block. Defaults to True.

        Returns:
            torch.Tensor | dict: Generated image tensor, or (when
                ``return_latents`` or ``return_noise`` is set) a dict with
                keys ``fake_img``, ``latent``, ``inject_index`` and
                ``noise_batch``.

        Raises:
            TypeError: If ``truncation < 1`` but no ``truncation_latent`` is
                given.
        """
        # receive noise and conduct sanity check.
        if isinstance(styles, torch.Tensor):
            # Check the last dimension (as the list branch below does) so a
            # per-layer latent of shape (N, num_latents, C) is accepted too;
            # for 2-D input this is identical to checking dimension 1.
            assert styles.shape[-1] == self.style_channels
            styles = [styles]
        elif mmcv.is_seq_of(styles, torch.Tensor):
            for t in styles:
                assert t.shape[-1] == self.style_channels
        # receive a noise generator and sample noise.
        elif callable(styles):
            assert num_batches > 0
            styles = self._sample_noise(styles, num_batches)
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0 and not input_is_latent
            styles = self._sample_noise(torch.randn, num_batches)

        if not input_is_latent:
            noise_batch = styles
            styles = [self.style_mapping(s) for s in styles]
        else:
            noise_batch = None

        if injected_noise is None:
            if randomize_noise:
                # ``None`` entries are handed to the style convs as
                # ``noise=None`` (see the calls below).
                injected_noise = [None] * self.num_injected_noises
            else:
                injected_noise = [
                    getattr(self, f'injected_noise_{i}')
                    for i in range(self.num_injected_noises)
                ]

        # use truncation trick
        if truncation < 1:
            if truncation_latent is None:
                # Fail with an explicit message instead of the opaque
                # tensor-minus-None TypeError raised by the arithmetic.
                raise TypeError(
                    '`truncation_latent` must be provided when '
                    '`truncation` < 1.')
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        # no style mixing
        if len(styles) < 2:
            inject_index = self.num_latents

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # already a per-layer latent
                latent = styles[0]
        # style mixing
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.num_latents - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(
                1, self.num_latents - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 4x4 stage
        out = self.constant_input(latent)
        out = self.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher resolutions
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], injected_noise[1::2],
                injected_noise[2::2], self.to_rgbs):
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)

            _index += 2

        img = skip

        if self.bgr2rgb:
            # ``torch.flip`` documents ``dims`` as a list or tuple; a bare
            # int is not accepted by the functional form.
            img = torch.flip(img, dims=[1])

        if return_latents or return_noise:
            output_dict = dict(
                fake_img=img,
                latent=latent,
                inject_index=inject_index,
                noise_batch=noise_batch)
            return output_dict

        return img
@COMPONENTS.register_module()
class StyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    This module comes from MMGeneration. In the future, this code will be
    removed and StyleGANv2 will be directly imported from mmgeneration.

    The architecture of this discriminator is proposed in StyleGAN2. More
    details can be found in: Analyzing and Improving the Image Quality of
    StyleGAN CVPR2020.

    You can load a pretrained model by passing information into the
    ``pretrained`` argument. We have already offered official weights as
    follows:

    - stylegan2-ffhq-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth  # noqa
    - stylegan2-horse-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth  # noqa
    - stylegan2-car-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth  # noqa
    - stylegan2-cat-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth  # noqa
    - stylegan2-church-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth  # noqa

    If you want to load the pretrained discriminator, you can just use the
    following code:

    .. code-block:: python

        # ckpt_http is one of the valid path from http source
        discriminator = StyleGAN2Discriminator(1024, 512,
                                               pretrained=dict(
                                                   ckpt_path=ckpt_http,
                                                   prefix='discriminator'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` with a local path.

    Note that our implementation adopts BGR image as input, while the
    original StyleGAN2 provides RGB images to the discriminator. Thus, we
    provide ``bgr2rgb`` argument to convert the image space. If your images
    follow the RGB order, please set it to ``True`` accordingly.

    Args:
        in_size (int): The input size of images. It is used as a key of the
            internal channel table, so it must be one of the resolutions
            listed there (4, 8, ..., 1024).
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blurry kernel. Defaults
            to [1, 3, 3, 1].
        mbstd_cfg (dict, optional): Configs for minibatch-stddev layer.
            Defaults to dict(group_size=4, channel_groups=1).
        pretrained (dict | None, optional): Information for pretrained
            models. The necessary key is 'ckpt_path'. Besides, you can also
            provide 'prefix' to load the discriminator part from the whole
            state dict. Defaults to None.
        bgr2rgb (bool, optional): Whether to flip the image channel dimension.
            Defaults to False.
    """

    def __init__(self,
                 in_size,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 mbstd_cfg=dict(group_size=4, channel_groups=1),
                 pretrained=None,
                 bgr2rgb=False):
        super(StyleGAN2Discriminator, self).__init__()
        self.bgr2rgb = bgr2rgb

        # feature-map resolution -> number of channels
        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        log_size = int(np.log2(in_size))

        in_channels = channels[in_size]

        # 1x1 conv lifting the 3-channel image to the feature width
        convs = [ConvDownLayer(3, channels[in_size], 1)]

        # one residual block per resolution step, from ``in_size`` down to 4
        for i in range(log_size, 2, -1):
            out_channel = channels[2**(i - 1)]
            convs.append(ResBlock(in_channels, out_channel, blur_kernel))
            in_channels = out_channel

        self.convs = nn.Sequential(*convs)

        self.mbstd_layer = ModMBStddevLayer(**mbstd_cfg)

        # NOTE(review): the ``+ 1`` presumably accounts for a single extra
        # feature map appended by the minibatch-stddev layer (matching the
        # default ``channel_groups=1``) -- confirm against ModMBStddevLayer.
        self.final_conv = ConvDownLayer(in_channels + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinearActModule(
                channels[4] * 4 * 4,
                channels[4],
                act_cfg=dict(type='fused_bias')),
            EqualLinearActModule(channels[4], 1),
        )

        if pretrained is not None:
            self._load_pretrained_model(**pretrained)

    def _load_pretrained_model(self,
                               ckpt_path,
                               prefix='',
                               map_location='cpu',
                               strict=True):
        """Load weights for this discriminator from a checkpoint.

        Args:
            ckpt_path (str): Path or URL of the checkpoint.
            prefix (str, optional): Prefix forwarded to
                ``_load_checkpoint_with_prefix`` to pick this module's part
                out of the whole state dict. Defaults to ''.
            map_location (str, optional): Location the checkpoint is mapped
                to while loading. Defaults to 'cpu'.
            strict (bool, optional): Forwarded to ``load_state_dict``.
                Defaults to True.
        """
        state_dict = _load_checkpoint_with_prefix(prefix, ckpt_path,
                                                  map_location)
        self.load_state_dict(state_dict, strict=strict)
        mmcv.print_log(f'Load pretrained model from {ckpt_path}', 'mmedit')

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input image tensor.

        Returns:
            torch.Tensor: Predict score for the input image.
        """
        if self.bgr2rgb:
            # ``torch.flip`` documents ``dims`` as a list or tuple; a bare
            # int is not accepted by the functional form.
            x = torch.flip(x, dims=[1])

        x = self.convs(x)

        x = self.mbstd_layer(x)
        x = self.final_conv(x)
        # flatten everything but the batch dimension for the linear head
        x = x.view(x.shape[0], -1)
        x = self.final_linear(x)

        return x