# Source code for mmedit.models.editors.nafnet.naf_avgpool2d
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
class NAFAvgPool2d(BaseModule):
    """Average Pooling 2D used in NAFNet.

    Note: this is different from the normal ``AvgPool2d`` in PyTorch.
    Following "Improving Image Restoration by Revisiting Global Information
    Aggregation", statistics are aggregated over a local window around each
    pixel rather than over the whole feature map (global average pooling).
    Window sums are obtained from a 2D cumulative sum (summed-area table)
    with a four-corner difference, so the cost does not grow with the window
    size.

    Args:
        kernel_size (int | Sequence[int], optional): Local window size
            ``(kh, kw)``. An int is used for both dimensions. If ``None``,
            it is derived lazily on the first forward pass from
            ``base_size`` and ``train_size`` and then cached. Default: None.
        base_size (int | Sequence[int], optional): Window size at the
            training resolution; only used when ``kernel_size`` is None.
            Default: None.
        auto_pad (bool): Replicate-pad the pooled map back to the spatial
            size of the input. Default: True.
        fast_imp (bool): Use a faster but non-equivalent implementation
            that pools a strided sub-sample of the input and upsamples the
            result. Default: False.
        train_size (Sequence[int], optional): Shape of the training input;
            only its last two entries ``(H, W)`` are read. Required when
            ``kernel_size`` is derived from ``base_size``. Default: None.
    """

    def __init__(self,
                 kernel_size=None,
                 base_size=None,
                 auto_pad=True,
                 fast_imp=False,
                 train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        # Candidate sub-sampling factors, tried from largest to smallest.
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        # NOTE: 'stride' reports kernel_size, kept as-is so the module repr
        # stays backward compatible; the local average itself is evaluated
        # at every pixel.
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp)

    def forward(self, x):
        """Average ``x`` over a local window around every pixel.

        Args:
            x (torch.Tensor): Input feature map, indexed as (N, C, H, W)
                (the code reads the last two dims as height and width).

        Returns:
            torch.Tensor: ``(N, C, 1, 1)`` when the window covers the whole
            input (identical to global average pooling); otherwise the
            locally averaged map, replicate-padded back to ``(H, W)`` when
            ``auto_pad`` is True.

        Raises:
            ValueError: If neither ``kernel_size`` nor ``base_size`` is set,
                or ``base_size`` is set without ``train_size``.
        """
        if self.kernel_size is None:
            self._init_kernel_size(x)
        elif isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        h, w = x.shape[-2:]
        if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
            # The window covers the whole map: plain global average pooling.
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            out = self._fast_local_avg(x)
        else:
            out = self._local_avg(x)

        if self.auto_pad:
            out = self._pad_to(out, h, w)
        return out

    def _init_kernel_size(self, x):
        """Scale ``base_size`` from ``train_size`` to the size of ``x``.

        The derived ``kernel_size`` and fast-path limits ``max_r1`` /
        ``max_r2`` are cached on the module, so this runs only once.
        """
        if not self.base_size:
            raise ValueError(
                'NAFAvgPool2d requires either kernel_size or base_size.')
        if self.train_size is None:
            raise ValueError('NAFAvgPool2d requires train_size when '
                             'kernel_size is derived from base_size.')
        train_h, train_w = self.train_size[-2], self.train_size[-1]
        if isinstance(self.base_size, int):
            self.base_size = (self.base_size, self.base_size)
        h, w = x.shape[2], x.shape[3]
        self.kernel_size = list(self.base_size)
        self.kernel_size[0] = h * self.base_size[0] // train_h
        self.kernel_size[1] = w * self.base_size[1] // train_w
        # only used for fast implementation
        self.max_r1 = max(1, self.rs[0] * h // train_h)
        self.max_r2 = max(1, self.rs[0] * w // train_w)

    def _local_avg(self, x):
        """Exact local mean over a ``kernel_size`` window at every pixel."""
        h, w = x.shape[-2:]
        # Summed-area table with a leading zero row/column (pad 0 for
        # convenience), so every window sum is a four-corner difference.
        s = F.pad(x.cumsum(dim=-1).cumsum_(dim=-2), (1, 0, 1, 0))
        # Clamp to >= 1: a zero window would slice with ``:-0`` (empty).
        k1 = max(1, min(h, self.kernel_size[0]))
        k2 = max(1, min(w, self.kernel_size[1]))
        out = (s[:, :, k1:, k2:] + s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] -
               s[:, :, k1:, :-k2])
        return out / (k1 * k2)

    def _fast_local_avg(self, x):
        """Approximate local mean computed on a strided sub-sample of ``x``.

        The map is sub-sampled by ``(r1, r2)``, averaged with a window of
        ``kernel_size // r`` on the coarse grid, then upsampled (nearest)
        by the same factors. Falls back to :meth:`_local_avg` when the
        sub-sampled map is too small to difference.
        """
        h, w = x.shape[-2:]
        # Largest factor in ``self.rs`` dividing each dimension, capped by
        # the reduction_constraint limits derived from the training size.
        r1 = min(self.max_r1, [r for r in self.rs if h % r == 0][0])
        r2 = min(self.max_r2, [r for r in self.rs if w % r == 0][0])
        sub = x[:, :, ::r1, ::r2]
        hs, ws = sub.shape[-2:]
        if hs < 2 or ws < 2:
            # A 1-row/1-column table leaves nothing to difference; the
            # fast path would otherwise produce an empty map.
            return self._local_avg(x)
        s = sub.cumsum(dim=-1).cumsum(dim=-2)
        # Window on the coarse grid, at least 1 so no slice is ``:-0``.
        k1 = max(1, min(hs - 1, self.kernel_size[0] // r1))
        k2 = max(1, min(ws - 1, self.kernel_size[1] // r2))
        out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] -
               s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (
                   k1 * k2)
        return F.interpolate(out, scale_factor=(r1, r2))

    @staticmethod
    def _pad_to(out, h, w):
        """Replicate-pad ``out`` symmetrically to spatial size ``(h, w)``."""
        _h, _w = out.shape[-2:]
        pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2,
                 (h - _h + 1) // 2)
        return F.pad(out, pad2d, mode='replicate')
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively replace global ``nn.AdaptiveAvgPool2d`` with NAFAvgPool2d.

    Every ``nn.AdaptiveAvgPool2d`` found in ``model`` (at any depth) is
    swapped in place for ``NAFAvgPool2d(base_size=base_size,
    fast_imp=fast_imp, train_size=train_size)``.

    Args:
        model (nn.Module): Module whose children are modified in place.
        base_size: Forwarded to ``NAFAvgPool2d(base_size=...)``.
        train_size: Forwarded to ``NAFAvgPool2d(train_size=...)``.
        fast_imp (bool): Forwarded to ``NAFAvgPool2d(fast_imp=...)``.
        **kwargs: Accepted for call compatibility and passed down the
            recursion; not otherwise used.

    Raises:
        ValueError: If an ``nn.AdaptiveAvgPool2d`` does not pool to 1x1,
            since only global pooling can be replaced by a local average.
    """
    for name, child in model.named_children():
        if next(child.children(), None) is not None:
            # compound module, go inside it
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)
        if isinstance(child, nn.AdaptiveAvgPool2d):
            # ``output_size`` may be given as 1, (1, 1) or [1, 1].
            output_size = child.output_size
            if isinstance(output_size, int):
                output_size = (output_size, output_size)
            if tuple(output_size) != (1, 1):
                raise ValueError(
                    'replace_layers only supports global pooling '
                    '(output_size 1), got output_size={}'.format(
                        child.output_size))
            pool = NAFAvgPool2d(
                base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            setattr(model, name, pool)
class Local_Base:
    """Mixin that converts a model to local average pooling.

    ``convert`` walks the model's children through ``replace_layers`` and
    then calls ``self.forward``, so the class it is mixed into is expected
    to be an ``nn.Module``.

    args:
        train_size: training image size (full input shape, e.g. passed to
            ``torch.rand`` to build the dummy input).
    """

    def convert(self, *args, train_size, **kwargs):
        """Swap global pools for ``NAFAvgPool2d`` and initialise them.

        Args:
            *args: Positional arguments forwarded to ``replace_layers``
                after the model (e.g. ``base_size``).
            train_size (Sequence[int]): Full shape of the training input.
                Forwarded to ``replace_layers`` and used to build a random
                input for one no-grad forward pass, which lets every
                ``NAFAvgPool2d`` derive and cache its kernel size.
            **kwargs: Keyword arguments forwarded to ``replace_layers``
                (e.g. ``fast_imp``).
        """
        replace_layers(self, *args, train_size=train_size, **kwargs)
        # Build the dummy input where the model lives, so conversion also
        # works after the model was moved to another device or float dtype.
        factory = {}
        param = next(self.parameters(), None)
        if param is not None:
            factory['device'] = param.device
            if param.is_floating_point():
                factory['dtype'] = param.dtype
        imgs = torch.rand(train_size, **factory)
        with torch.no_grad():
            self.forward(imgs)