mmedit.models.editors.swinir.swinir_utils 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
from itertools import repeat
# From PyTorch internals
[文档]def _ntuple(n):
"""A `to_tuple` function generator. It returns a function, this function
will repeat the input to a tuple of length ``n`` if the input is not an
Iterable object, otherwise, return the input directly.
Args:
n (int): The number of the target length.
"""
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
def drop_path(x,
              drop_prob: float = 0.,
              training: bool = False,
              scale_by_keep: bool = True):
    """Randomly zero entire samples of ``x`` (Stochastic Depth).

    Intended for the main path of residual blocks. One Bernoulli draw is
    made per element of the leading (batch) dimension and broadcast over
    all remaining dimensions, so a sample is either kept whole or zeroed.

    The name "drop path" is used instead of "DropConnect", because
    DropConnect is a different form of dropout from a separate paper. See
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x (Tensor): Input whose first dimension indexes samples.
        drop_prob (float): Probability of zeroing a sample. Default: 0.
        training (bool): When False, ``x`` is returned unchanged.
            Default: False.
        scale_by_keep (bool): When True (and the keep probability is
            positive), kept samples are divided by the keep probability.
            Default: True.

    Returns:
        Tensor: ``x`` itself if nothing is dropped, otherwise ``x``
        multiplied by the per-sample mask.
    """
    if drop_prob == 0. or not training:
        return x
    survival = 1 - drop_prob
    # Mask is (batch, 1, 1, ...) so it broadcasts across every other dim,
    # whatever the rank of x.
    mask_shape = (x.shape[0], *((1, ) * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if survival > 0.0 and scale_by_keep:
        # Dividing by the survival rate keeps the mask's expectation at 1.
        mask.div_(survival)
    return x * mask
def window_partition(x, window_size):
    """Split a channels-last feature map into non-overlapping square windows.

    Args:
        x (Tensor): Input of shape (B, H, W, C). H and W are expected to be
            multiples of ``window_size``; otherwise the ``view`` below is
            expected to raise due to an element-count mismatch.
        window_size (int): Side length of each window.

    Returns:
        Tensor: Windows of shape (num_windows * B, window_size,
        window_size, C), where num_windows = (H // window_size) *
        (W // window_size). All windows of one image are adjacent,
        ordered row-major over the window grid.
    """
    batch, height, width, channels = x.shape
    ws = window_size
    # (B, H, W, C) -> (B, nH, ws, nW, ws, C)
    grid = x.view(batch, height // ws, ws, width // ws, ws, channels)
    # Group the two grid axes together: (B, nH, nW, ws, ws, C).
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    # Flatten (B, nH, nW) into a single window axis.
    return grid.view(-1, ws, ws, channels)
def window_reverse(windows, window_size, H, W):
    """Reassemble square windows into channels-last feature maps.

    Inverse of ``window_partition``: the windows are expected to be ordered
    batch-major, then row-major over the (H // window_size) x
    (W // window_size) window grid.

    Args:
        windows (Tensor): Shape (num_windows * B, window_size,
            window_size, C).
        window_size (int): Side length of each window.
        H (int): Height of the output image.
        W (int): Width of the output image.

    Returns:
        Tensor: Shape (B, H, W, C).
    """
    n_h = H // window_size
    n_w = W // window_size
    # Recover the batch size with exact integer arithmetic instead of float
    # division followed by int() truncation, which can silently produce a
    # wrong B when the window count does not match the H x W grid.
    B = windows.shape[0] // (n_h * n_w)
    x = windows.view(B, n_h, n_w, window_size, window_size, -1)
    # (B, nH, nW, ws, ws, C) -> (B, nH, ws, nW, ws, C) -> (B, H, W, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x