Shortcuts

注意

您正在阅读 MMEditing 0.x。MMEditing 0.x 会在 2022 年末开始逐步停止维护,建议您及时升级到 MMEditing 1.0 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。

mmedit.models.common.upsample 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from .sr_backbone_utils import default_init_weights


class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    A convolution first expands the channel dimension to
    ``out_channels * scale_factor * scale_factor``; ``F.pixel_shuffle``
    then rearranges those channels into a spatially larger feature map.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of Conv layer to expand channels.

    Returns:
        Upsampled feature map.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        # pixel_shuffle folds scale_factor**2 channels into each output
        # channel, so the conv must emit that many channels per output
        # channel. The padding keeps the spatial size unchanged for odd
        # kernel sizes (stride is the Conv2d default of 1).
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        """Initialize weights for PixelShufflePack."""
        # NOTE(review): the second argument is presumably the weight scale
        # used by ``default_init_weights`` -- confirm against
        # ``sr_backbone_utils``.
        default_init_weights(self, 1)

    def forward(self, x):
        """Forward function for PixelShufflePack.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x