注意
您正在阅读 MMEditing 0.x 的文档。MMEditing 0.x 将于 2022 年末起逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，以享受 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多信息。
mmedit.models.common.flow_warp 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Output pixel ``(i, j)`` of sample ``b`` is read from ``x[b]`` at the
    absolute pixel position ``(j + flow[b, i, j, 0], i + flow[b, i, j, 1])``
    (x/width first, then y/height), using ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
            a two-channel, denoting the width and height relative offsets.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Forwarded to ``F.grid_sample``. The pixel
            coordinates are normalized to [-1, 1] with the convention that
            matches this flag, so a zero flow is an identity warp (up to
            floating-point rounding) for either value. Default: True.

    Returns:
        Tensor: Warped image or feature map.

    Raises:
        ValueError: If the spatial size (h, w) of ``x`` differs from that of
            ``flow``.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    device = flow.device

    # Absolute pixel-coordinate grid of shape (h, w, 2), channel 0 = x (column)
    # and channel 1 = y (row). Built by broadcasting instead of torch.meshgrid
    # to avoid the missing-`indexing` deprecation warning on newer torch while
    # staying compatible with older versions.
    xs = torch.arange(0, w, device=device, dtype=x.dtype)
    ys = torch.arange(0, h, device=device, dtype=x.dtype)
    grid = torch.stack(
        (xs.view(1, w).expand(h, w), ys.view(h, 1).expand(h, w)), dim=2)

    # Broadcasts (h, w, 2) + (n, h, w, 2) -> (n, h, w, 2) sampling positions.
    grid_flow = grid + flow

    # Map pixel coordinates to grid_sample's [-1, 1] space. The mapping must
    # match `align_corners`, otherwise the warp is mis-scaled and shifted by
    # half a pixel.
    if align_corners:
        # -1 / +1 are the centers of the first / last pixel.
        # max(..., 1) guards against division by zero for size-1 dimensions.
        grid_flow_x = 2.0 * grid_flow[..., 0] / max(w - 1, 1) - 1.0
        grid_flow_y = 2.0 * grid_flow[..., 1] / max(h - 1, 1) - 1.0
    else:
        # -1 / +1 are the outer edges of the first / last pixel.
        grid_flow_x = (2.0 * grid_flow[..., 0] + 1.0) / w - 1.0
        grid_flow_y = (2.0 * grid_flow[..., 1] + 1.0) / h - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)

    output = F.grid_sample(
        x,
        grid_flow,
        mode=interpolation,
        padding_mode=padding_mode,
        align_corners=align_corners)
    return output