Shortcuts

mmedit.models.editors.cyclegan.cyclegan_modules 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule


class ResidualBlockWithDropout(nn.Module):
    """Define a Residual Block with dropout layers.

    Ref:
    Deep Residual Learning for Image Recognition

    A residual block is a conv block with skip connections. A dropout layer
    is added between two common conv modules.

    Args:
        channels (int): Number of channels in the conv layer.
        padding_mode (str): The name of padding layer:
            'reflect' | 'replicate' | 'zeros'.
        norm_cfg (dict): Config dict to build norm layer.
            Default: `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: True.
    """

    def __init__(self,
                 channels,
                 padding_mode,
                 norm_cfg=dict(type='BN'),
                 use_dropout=True):
        super().__init__()
        assert isinstance(norm_cfg, dict), (
            f"'norm_cfg' should be dict, but got {type(norm_cfg)}")
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"

        # We use norm layers in the residual block with dropout layers.
        # Only for IN, use bias to follow cyclegan's original implementation.
        use_bias = norm_cfg['type'] == 'IN'

        # Both conv modules share every argument except the activation
        # config, so build the common part once.
        conv_kwargs = dict(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
            bias=use_bias,
            norm_cfg=norm_cfg,
            padding_mode=padding_mode)

        # First conv module keeps ConvModule's default activation config.
        block = [ConvModule(**conv_kwargs)]
        if use_dropout:
            block.append(nn.Dropout(0.5))
        # Second conv module has no activation (act_cfg=None): the skip
        # connection in ``forward`` is added without a final ReLU.
        block.append(ConvModule(act_cfg=None, **conv_kwargs))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        """Forward function. Add skip connections without final ReLU.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        return x + self.block(x)
class GANImageBuffer:
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size <= 0,
            no buffer is used and ``query`` returns its input unchanged.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer. Default: 0.5.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        # Always initialise the (possibly unused) buffer state so that
        # ``query`` never hits a missing attribute, whatever the size.
        self.img_num = 0
        self.image_buffer = []

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.

        Returns:
            Tensor: A batch with the same number of images as the input. If
                buffering is disabled, the input itself is returned.
                Otherwise each image is either the (detached) current image
                or a previously stored one.
        """
        # No buffering requested: hand the batch back untouched.
        if self.buffer_size <= 0:
            return images

        # An empty batch has nothing to store or swap, and ``torch.cat``
        # rejects an empty list, so return early.
        if len(images) == 0:
            return images.detach()

        return_images = []
        for image in images:
            # Detach from the autograd graph and restore the batch dim.
            image = torch.unsqueeze(image.detach(), 0)

            # If the buffer is not full, keep inserting current images.
            if self.img_num < self.buffer_size:
                self.img_num += 1
                self.image_buffer.append(image)
                return_images.append(image)
                continue

            # By self.buffer_ratio, the buffer will return a previously
            # stored image, and insert the current image into the buffer.
            if np.random.random() < self.buffer_ratio:
                random_id = np.random.randint(0, self.buffer_size)
                image_tmp = self.image_buffer[random_id].clone()
                self.image_buffer[random_id] = image
                return_images.append(image_tmp)
            # By (1 - self.buffer_ratio), the buffer will return the
            # current image.
            else:
                return_images.append(image)

        # Collect all the images and return.
        return torch.cat(return_images, 0)