Shortcuts

Source code for mmedit.models.editors.lsgan.lsgan_generator

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmedit.registry import MODELS
from ...utils import get_module_device


@MODELS.register_module()
class LSGANGenerator(nn.Module):
    """Generator for LSGAN.

    Implementation Details for LSGAN architecture:

    #. Adopt transposed convolution in the generator;
    #. Use batchnorm in the generator except for the final output layer;
    #. Use ReLU in the generator except for the final output layer;
    #. Keep channels of feature maps unchanged in the convolution backbone;
    #. Use one more 3x3 conv every upsampling in the convolution backbone.

    We follow the implementation details of the original paper:
    Least Squares Generative Adversarial Networks
    https://arxiv.org/pdf/1611.04076.pdf

    Args:
        output_scale (int, optional): Output scale for the generated image.
            ``output_scale // input_scale`` must be a power of two that is
            at least 4. Defaults to 128.
        out_channels (int, optional): The channel number of the output
            feature. Defaults to 3.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this
            number. Defaults to 256.
        input_scale (int, optional): The scale of the input 2D feature map.
            Defaults to 8.
        noise_size (int, optional): Size of the input noise vector.
            Defaults to 1024.
        conv_cfg (dict, optional): Config for the convolution module used in
            this generator. Defaults to dict(type='ConvTranspose2d').
        default_norm_cfg (dict, optional): Norm config for all of layers
            except for the final output layer. Defaults to dict(type='BN').
        default_act_cfg (dict, optional): Activation config for all of layers
            except for the final output layer. Defaults to dict(type='ReLU').
        out_act_cfg (dict, optional): Activation config for the final output
            layer. Defaults to dict(type='Tanh').
    """

    # NOTE(review): the dict defaults are shared across instances. This
    # class only derives new dicts from them (``dict(conv_cfg, ...)``).
    # It assumes ConvModule / MODELS.build copy their cfgs rather than
    # mutating them; confirm.
    def __init__(self,
                 output_scale=128,
                 out_channels=3,
                 base_channels=256,
                 input_scale=8,
                 noise_size=1024,
                 conv_cfg=dict(type='ConvTranspose2d'),
                 default_norm_cfg=dict(type='BN'),
                 default_act_cfg=dict(type='ReLU'),
                 out_act_cfg=dict(type='Tanh')):
        super().__init__()
        assert output_scale % input_scale == 0
        assert output_scale // input_scale >= 4
        # Every upsampling step doubles the spatial size, so the ratio must
        # be a power of two. Otherwise ``int(np.log2(...))`` below would
        # silently truncate and the generator would emit a smaller image
        # than ``output_scale``.
        assert float(np.log2(output_scale // input_scale)).is_integer(), (
            'output_scale // input_scale must be a power of two, but got '
            f'{output_scale // input_scale}')

        self.output_scale = output_scale
        self.base_channels = base_channels
        self.input_scale = input_scale
        self.noise_size = noise_size

        # Linear projection of the noise vector. ``forward`` reshapes the
        # result to (-1, base_channels, input_scale, input_scale).
        self.noise2feat_head = nn.Sequential(
            nn.Linear(noise_size, input_scale * input_scale * base_channels))
        # NOTE: this BatchNorm2d is hard-coded and does not follow
        # ``default_norm_cfg``.
        self.noise2feat_tail = nn.Sequential(nn.BatchNorm2d(base_channels))
        if default_act_cfg is not None:
            self.noise2feat_tail.add_module('act',
                                            MODELS.build(default_act_cfg))

        # Number of stride-2 upsampling steps inside the backbone. The two
        # output blocks built below contribute the remaining two steps,
        # hence the ``- 2``.
        self.num_upsamples = int(np.log2(output_scale // input_scale)) - 2

        # build up convolution backbone (excluding the output layer)
        self.conv_blocks = nn.ModuleList()
        for _ in range(self.num_upsamples):
            # stride 2 + output_padding=1: with the default ConvTranspose2d
            # (kernel 3, padding 1) this doubles H and W.
            self.conv_blocks.append(
                ConvModule(
                    base_channels,
                    base_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=dict(conv_cfg, output_padding=1),
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))
            # extra 3x3, stride-1 conv keeping the channel count unchanged
            self.conv_blocks.append(
                ConvModule(
                    base_channels,
                    base_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=default_norm_cfg,
                    act_cfg=default_act_cfg))

        # output blocks: two more stride-2 layers halving the channels each
        # time, then a 3x3 conv to ``out_channels`` with no normalization
        # and ``out_act_cfg`` as activation.
        self.conv_blocks.append(
            ConvModule(
                base_channels,
                int(base_channels // 2),
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=dict(conv_cfg, output_padding=1),
                norm_cfg=default_norm_cfg,
                act_cfg=default_act_cfg))
        self.conv_blocks.append(
            ConvModule(
                int(base_channels // 2),
                int(base_channels // 4),
                kernel_size=3,
                stride=2,
                padding=1,
                conv_cfg=dict(conv_cfg, output_padding=1),
                norm_cfg=default_norm_cfg,
                act_cfg=default_act_cfg))
        self.conv_blocks.append(
            ConvModule(
                int(base_channels // 4),
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=None,
                act_cfg=out_act_cfg))

    def forward(self, noise, num_batches=0, return_noise=False):
        """Forward function.

        Args:
            noise (torch.Tensor | callable | None): You can directly give a
                batch of noise through a ``torch.Tensor`` of shape
                (n, noise_size), or offer a callable function to sample a
                batch of noise data. The callable is called with the shape
                tuple ``(num_batches, noise_size)``. Otherwise, the ``None``
                indicates to use the default noise sampler (``torch.randn``).
            num_batches (int, optional): The batch size. Must be positive
                when ``noise`` is not a tensor. Defaults to 0.
            return_noise (bool, optional): If True, ``noise_batch`` will be
                returned in a dict with ``fake_img``. Defaults to False.

        Returns:
            torch.Tensor | dict: If not ``return_noise``, only the output
                image will be returned. Otherwise, a dict containing
                ``fake_img`` and ``noise_batch`` will be returned.

        Raises:
            ValueError: If a tensor ``noise`` is not 2-dimensional.
        """
        # receive noise and conduct sanity check.
        if isinstance(noise, torch.Tensor):
            # Check the rank before indexing ``shape[1]`` so a 1-D/0-D
            # tensor gets this ValueError instead of an IndexError.
            if noise.ndim != 2:
                raise ValueError('The noise should be in shape of (n, c), '
                                 f'but got {noise.shape}')
            assert noise.shape[1] == self.noise_size
            noise_batch = noise
        # receive a noise generator and sample noise.
        elif callable(noise):
            noise_generator = noise
            assert num_batches > 0
            noise_batch = noise_generator((num_batches, self.noise_size))
        # otherwise, we will adopt default noise sampler.
        else:
            assert num_batches > 0
            noise_batch = torch.randn((num_batches, self.noise_size))

        # move the noise onto the device reported by get_module_device(self)
        noise_batch = noise_batch.to(get_module_device(self))

        # noise2feat
        x = self.noise2feat_head(noise_batch)
        x = x.reshape(
            (-1, self.base_channels, self.input_scale, self.input_scale))
        x = self.noise2feat_tail(x)

        # conv module
        for conv in self.conv_blocks:
            x = conv(x)

        if return_noise:
            return dict(fake_img=x, noise_batch=noise_batch)

        return x