# Source code for mmedit.models.editors.stylegan3.stylegan3_utils
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
# The mmcv ops are optional. When they cannot be imported, bind BOTH names to
# None so that each of them always exists at module level; previously only
# ``upsample2d`` was bound and any later reference to ``filter2d`` raised
# NameError.
try:
    from mmcv.ops import filter2d, upsample2d
except ImportError:
    filter2d = None
    upsample2d = None
    print(
        'Warning: mmcv.ops.filter2d and mmcv.ops.upsample2d are not available.'
    )
def apply_integer_translation(x, tx, ty):
    """Shift a batch of images by a whole number of pixels.

    The offsets are given as fractions of the image size; they are scaled by
    the width / height and rounded to the nearest integer. Areas that have no
    source pixel are left at zero.

    Args:
        x (torch.Tensor): Input with exactly four dimensions; the last two
            are unpacked as height and width.
        tx (float | torch.Tensor): Horizontal shift, multiplied by the width
            before rounding.
        ty (float | torch.Tensor): Vertical shift, multiplied by the height
            before rounding.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The shifted tensor and a mask of
        the same shape that is 1 where pixels were copied from ``x`` and 0
        elsewhere.
    """
    _, _, height, width = x.shape
    shift_x = torch.as_tensor(tx * width).to(
        dtype=torch.float32, device=x.device)
    shift_y = torch.as_tensor(ty * height).to(
        dtype=torch.float32, device=x.device)
    dx = shift_x.round().to(torch.int64)
    dy = shift_y.round().to(torch.int64)

    shifted = torch.zeros_like(x)
    mask = torch.zeros_like(x)

    # A shift of a full image size (or more) leaves nothing to copy.
    if abs(dx) >= width or abs(dy) >= height:
        return shifted, mask

    # Source window in ``x`` and the matching destination window.
    src_rows = slice(max(-dy, 0), height + min(-dy, 0))
    src_cols = slice(max(-dx, 0), width + min(-dx, 0))
    dst_rows = slice(max(dy, 0), height + min(dy, 0))
    dst_cols = slice(max(dx, 0), width + min(dx, 0))

    shifted[:, :, dst_rows, dst_cols] = x[:, :, src_rows, src_cols]
    mask[:, :, dst_rows, dst_cols] = 1
    return shifted, mask
def sinc(x):
    """Evaluate ``sin(pi * x) / (pi * x)`` element-wise.

    Args:
        x (torch.Tensor): Input values.

    Returns:
        torch.Tensor: The sinc values; entries whose ``|pi * x|`` is below
        ``1e-30`` are set to exactly 1.
    """
    tiny = 1e-30
    arg = torch.abs(x * np.pi)
    # Clamp the denominator so the division never sees a zero.
    ratio = arg.sin() / arg.clamp(min=tiny)
    return torch.where(arg < tiny, torch.ones_like(x), ratio)
def apply_fractional_translation(x, tx, ty, a=3):
    """Shift a batch of images by a sub-pixel offset.

    The offset is split into an integer part (floor) and a fractional part.
    The fractional part is applied by filtering with a separable kernel of
    ``2 * a`` taps built as ``sinc(t) * sinc(t / a)``; the integer part is
    applied by cropping and pasting into a zero-initialised output.

    Args:
        x (torch.Tensor): Input with exactly four dimensions; the last two
            are unpacked as height and width.
        tx (float | torch.Tensor): Horizontal shift, multiplied by the width.
        ty (float | torch.Tensor): Vertical shift, multiplied by the height.
        a (int): Kernel parameter; the kernel has ``2 * a`` taps.
            Defaults to 3.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The shifted tensor and a mask of
        the same shape. The mask is 1 on a window that is narrower than the
        written output window by the kernel extent (presumably the pixels
        unaffected by the zero padding) and 0 elsewhere.
    """
    _, _, height, width = x.shape
    shift_x = torch.as_tensor(tx * width).to(
        dtype=torch.float32, device=x.device)
    shift_y = torch.as_tensor(ty * height).to(
        dtype=torch.float32, device=x.device)
    whole_x = shift_x.floor().to(torch.int64)
    whole_y = shift_y.floor().to(torch.int64)
    frac_x = shift_x - whole_x
    frac_y = shift_y - whole_y
    reach = a - 1

    def unit_sum_kernel(frac):
        # 1-D windowed-sinc taps for a fractional offset, scaled to sum to 1.
        taps = torch.arange(a * 2, device=x.device) - reach
        dist = taps - frac
        kernel = sinc(dist) * sinc(dist / a)
        return kernel / kernel.sum()

    # Destination window of the filtered image inside the output.
    shifted = torch.zeros_like(x)
    out_x0 = max(whole_x - reach, 0)
    out_y0 = max(whole_y - reach, 0)
    out_x1 = min(whole_x + a, 0) + width
    out_y1 = min(whole_y + a, 0) + height
    if out_x0 < out_x1 and out_y0 < out_y1:
        kernel_x = unit_sum_kernel(frac_x).unsqueeze(0)
        kernel_y = unit_sum_kernel(frac_y).unsqueeze(1)
        filtered = filter2d(x, kernel_x, padding=[reach, a, 0, 0])
        filtered = filter2d(filtered, kernel_y, padding=[0, 0, reach, a])
        crop_rows = slice(
            max(reach - whole_y, 0),
            height + reach + a + min(-whole_y - a, 0))
        crop_cols = slice(
            max(reach - whole_x, 0),
            width + reach + a + min(-whole_x - a, 0))
        shifted[:, :, out_y0:out_y1, out_x0:out_x1] = \
            filtered[:, :, crop_rows, crop_cols]

    # Mask window, shrunk relative to the destination window above.
    mask = torch.zeros_like(x)
    valid_x0 = max(whole_x + a, 0)
    valid_y0 = max(whole_y + a, 0)
    valid_x1 = min(whole_x - reach, 0) + width
    valid_y1 = min(whole_y - reach, 0) + height
    if valid_x0 < valid_x1 and valid_y0 < valid_y1:
        mask[:, :, valid_y0:valid_y1, valid_x0:valid_x1] = 1
    return shifted, mask
def rotation_matrix(angle):
    """Build a 3x3 homogeneous matrix for a 2D rotation.

    Args:
        angle (float | torch.Tensor): Rotation angle in radians.

    Returns:
        torch.Tensor: A float32 matrix on the device of ``angle`` laid out as
        ``[[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]``.
    """
    theta = torch.as_tensor(angle).to(torch.float32)
    cos_t = theta.cos()
    sin_t = theta.sin()

    # Start from identity and overwrite the upper-left 2x2 rotation part.
    rot = torch.eye(3, device=theta.device)
    rot[0, 0] = cos_t
    rot[0, 1] = sin_t
    rot[1, 0] = -sin_t
    rot[1, 1] = cos_t
    return rot
def lanczos_window(x, a):
    """Evaluate a window that is ``sinc(|x| / a)`` inside ``|x| < a``.

    Args:
        x (torch.Tensor): Sample positions.
        a (int | float): Window half-width.

    Returns:
        torch.Tensor: ``sinc(|x| / a)`` where ``|x| / a < 1`` and 0 elsewhere.
    """
    scaled = torch.abs(x) / a
    inside = scaled < 1
    return torch.where(inside, sinc(scaled), torch.zeros_like(scaled))
def construct_affine_bandlimit_filter(mat,
                                      a=3,
                                      amax=16,
                                      aflt=64,
                                      up=4,
                                      cutoff_in=1,
                                      cutoff_out=1):
    """Build a 2D FIR filter for resampling under an affine transform.

    Two sinc filters (one on the input grid, one on the grid transformed by
    ``mat``) are circularly convolved, and the result is multiplied by the
    circular convolution of the two matching Lanczos windows.

    Args:
        mat (torch.Tensor | array-like): Transform matrix; only the upper-left
            2x2 part is used to map tap coordinates.
        a (int): Half-width handed to ``lanczos_window``. Defaults to 3.
        amax (int): Half-extent of the returned filter. Defaults to 16.
        aflt (int): Half-extent of the grid on which the filter is evaluated
            before cropping. Defaults to 64.
        up (int): Number of taps per unit of coordinate. Defaults to 4.
        cutoff_in (int | float): Scale applied to the input coordinates
            inside ``sinc``. Defaults to 1.
        cutoff_out (int | float): Scale applied to the transformed
            coordinates inside ``sinc``. Defaults to 1.

    Returns:
        torch.Tensor: A square 2D filter with ``amax * 2 * up - 1`` taps per
        side.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    def circular_conv(lhs, rhs):
        # Circular convolution computed as a product of the two spectra.
        spectrum = torch.fft.fftn(lhs) * torch.fft.fftn(rhs)
        return torch.fft.ifftn(spectrum).real

    # Tap coordinates on the input grid and on the transformed grid.
    span = aflt * up
    taps = (torch.arange(span * 2 - 1, device=mat.device) + 1) / up - aflt
    taps = taps.roll(1 - span)
    y_in, x_in = torch.meshgrid(taps, taps)
    coords_out = torch.stack([x_in, y_in], dim=2) @ mat[:2, :2].t()
    x_out, y_out = coords_out.unbind(2)

    # Separable sinc filter on each of the two grids.
    sinc_in = sinc(x_in * cutoff_in) * sinc(y_in * cutoff_in)
    sinc_out = sinc(x_out * cutoff_out) * sinc(y_out * cutoff_out)

    # Separable Lanczos window on each of the two grids.
    window_in = lanczos_window(x_in, a) * lanczos_window(y_in, a)
    window_out = lanczos_window(x_out, a) * lanczos_window(y_out, a)

    # Windowed FIR filter.
    kernel = (circular_conv(sinc_in, sinc_out) *
              circular_conv(window_in, window_out))

    # Roll the taps back, then crop from half-extent aflt down to amax.
    crop = (aflt - amax) * up
    kernel = kernel.roll([span - 1] * 2, dims=[0, 1])
    kernel = kernel[crop:-crop, crop:-crop]

    # Pad to a multiple of ``up`` so the taps can be grouped into the
    # up x up phases, each of which is scaled to sum to 1 / up**2.
    kernel = torch.nn.functional.pad(kernel, [0, 1, 0, 1])
    kernel = kernel.reshape(amax * 2, up, amax * 2, up)
    kernel = kernel / kernel.sum([0, 2], keepdim=True) / (up**2)

    # Undo the grouping and drop the padded row / column again.
    side = amax * 2 * up
    return kernel.reshape(side, side)[:-1, :-1]
def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Apply an affine transform to a batch of images.

    The input is first upsampled by ``up`` with a filter built by
    ``construct_affine_bandlimit_filter`` and then resampled bilinearly on an
    affine sampling grid derived from the inverse of ``mat``.

    Args:
        x (torch.Tensor): Input with exactly four dimensions; the last two
            are unpacked as height and width.
        mat (torch.Tensor | array-like): 3x3 transform matrix (it is inverted
            and its first two rows of three columns form the grid).
        up (int): Upsampling factor. Defaults to 4.
        **filter_kwargs: Extra keyword arguments forwarded to
            ``construct_affine_bandlimit_filter``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The transformed tensor and a mask
        obtained by resampling (nearest) a map that is 1 in the interior of
        the upsampled image and 0 in a border of ``2 * p + 1`` pixels, where
        ``p`` is half the filter size.
    """
    _, _, height, width = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
    functional = torch.nn.functional

    # Filter used for the upsampling step.
    kernel = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert (kernel.ndim == 2 and kernel.shape[0] == kernel.shape[1]
            and kernel.shape[0] % 2 == 1)
    pad = kernel.shape[0] // 2

    # Sampling grid. The in-place edits below are order-dependent per row:
    # the translation is doubled, offset, and only then is the row rescaled.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    for row, size in ((0, width), (1, height)):
        theta[row, 2] += 1 / up / size
        theta[row, :] *= size / (size + pad / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    grid = functional.affine_grid(theta, x.shape, align_corners=False)

    # Upsample, then resample on the grid.
    upsampled = upsample2d(input=x, filter=kernel, up=up, padding=pad)
    resampled = functional.grid_sample(
        upsampled,
        grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False)

    # Mask: interior of the upsampled image, pushed through the same grid.
    border = pad * 2 + 1
    mask = torch.zeros_like(upsampled)
    mask[:, :, border:-border, border:-border] = 1
    mask = functional.grid_sample(
        mask,
        grid,
        mode='nearest',
        padding_mode='zeros',
        align_corners=False)
    return resampled, mask
def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
    """Rotate a batch of images by an arbitrary angle.

    Builds the rotation matrix for ``angle`` and delegates to
    ``apply_affine_transformation`` with ``amax`` set to ``2 * a``.

    Args:
        x (torch.Tensor): Input images.
        angle (float | torch.Tensor): Rotation angle, forwarded to
            ``rotation_matrix``.
        a (int): Filter parameter forwarded as ``a``. Defaults to 3.
        **filter_kwargs: Extra keyword arguments forwarded to
            ``apply_affine_transformation``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Whatever
        ``apply_affine_transformation`` returns (transformed tensor, mask).
    """
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    rot = rotation_matrix(theta)
    return apply_affine_transformation(
        x, rot, a=a, amax=a * 2, **filter_kwargs)
def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
    """Filter a batch of images with the filter of a rotation by ``-angle``.

    No resampling takes place: the filter from
    ``construct_affine_bandlimit_filter`` (built with ``up=1`` and
    ``amax = 2 * a``) is applied to ``x`` with ``filter2d``.

    Args:
        x (torch.Tensor): Input images.
        angle (float | torch.Tensor): Angle; its negation is forwarded to
            ``rotation_matrix``.
        a (int): Filter parameter forwarded as ``a``. Defaults to 3.
        **filter_kwargs: Extra keyword arguments forwarded to
            ``construct_affine_bandlimit_filter``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The filtered tensor and a mask of
        the same shape that is 1 in the interior and 0 in a border whose
        width is half the filter size.
    """
    theta = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
    kernel = construct_affine_bandlimit_filter(
        rotation_matrix(-theta), a=a, amax=a * 2, up=1, **filter_kwargs)
    filtered = filter2d(input=x, filter=kernel)

    # Zero out a border of half the filter size in the mask.
    border = kernel.shape[0] // 2
    mask = torch.zeros_like(filtered)
    mask[:, :, border:-border, border:-border] = 1
    return filtered, mask