Source code for mmedit.models.editors.liif.liif_net
from abc import abstractmethod
import torch
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmedit.registry import MODELS
from mmedit.utils import make_coord
class LIIFNet(BaseModule):
    """LIIF net for single image super-resolution, CVPR, 2021.

    Paper: Learning Continuous Image Representation with
    Local Implicit Image Function

    The subclasses should define `generator` with `encoder` and `imnet`,
    and override the method `gen_feature`.
    If `encoder` does not contain `mid_channels`, `__init__` should be
    overridden.

    Args:
        encoder (dict): Config for the generator.
        imnet (dict): Config for the imnet. The passed dict is not
            modified; ``in_dim`` is filled in on a copy before building.
        local_ensemble (bool): Whether to use local ensemble. Default: True.
        feat_unfold (bool): Whether to use feature unfold. Default: True.
        cell_decode (bool): Whether to use cell decode. Default: True.
        eval_bsize (int): Size of batched predict. Must be positive when
            batched prediction is used. Default: None.
    """

    def __init__(self,
                 encoder,
                 imnet,
                 local_ensemble=True,
                 feat_unfold=True,
                 cell_decode=True,
                 eval_bsize=None):
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode
        self.eval_bsize = eval_bsize

        # model
        self.encoder = MODELS.build(encoder)
        imnet_in_dim = self.encoder.mid_channels
        if self.feat_unfold:
            # query_rgb stacks each pixel's 3x3 neighbourhood on channels.
            imnet_in_dim *= 9
        imnet_in_dim += 2  # attach coordinates
        if self.cell_decode:
            imnet_in_dim += 2  # attach the (rescaled) cell tensor
        # Build from a shallow copy so the caller's config is not mutated
        # (the same config object may be shared or reused elsewhere).
        imnet = dict(imnet, in_dim=imnet_in_dim)
        self.imnet = MODELS.build(imnet)

    def forward(self, x, coord, cell, test_mode=False):
        """Forward function.

        Args:
            x (Tensor): Input tensor, passed to ``gen_feature``.
            coord (Tensor): coordinates tensor.
            cell (Tensor): cell tensor.
            test_mode (bool): Whether in test mode or not. Default: False.

        Returns:
            pred (Tensor): output of model.
        """
        feature = self.gen_feature(x)
        # Chunked querying is only used at test time and only when a chunk
        # size was configured.
        if self.eval_bsize is None or not test_mode:
            pred = self.query_rgb(feature, coord, cell)
        else:
            pred = self.batched_predict(feature, coord, cell)

        return pred

    def query_rgb(self, feature, coord, cell=None):
        """Query RGB value of GT.

        Adapted from 'https://github.com/yinboc/liif.git'
        'liif/models/liif.py'
        Copyright (c) 2020, Yinbo Chen, under BSD 3-Clause License.

        Args:
            feature (Tensor): encoded feature, shape (B, C, H, W).
            coord (Tensor): coord tensor, shape (B, Q, 2).
            cell (Tensor | None): cell tensor, shape (B, Q, 2). Only read
                when ``self.cell_decode`` is True. Default: None.

        Returns:
            result (Tensor): (part of) output, shape (B, Q, C_out).
        """
        if self.imnet is None:
            # No decoder: plain nearest-neighbour lookup of the feature.
            coord = coord.type(feature.type())
            result = F.grid_sample(
                feature,
                coord.flip(-1).unsqueeze(1),
                mode='nearest',
                align_corners=False)
            result = result[:, :, 0, :].permute(0, 2, 1)
            return result

        if self.feat_unfold:
            # Concatenate each pixel's 3x3 neighbourhood along channels.
            feature = F.unfold(
                feature, 3,
                padding=1).view(feature.shape[0], feature.shape[1] * 9,
                                feature.shape[2], feature.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # field radius (global: [-1, 1])
        radius_x = 2 / feature.shape[-2] / 2
        radius_y = 2 / feature.shape[-1] / 2

        feat_coord = make_coord(feature.shape[-2:], flatten=False) \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(feature.shape[0], 2, *feature.shape[-2:])
        feat_coord = feat_coord.to(coord)
        # Loop-invariant cast: every shifted query coord below is cast to
        # feature.type(), so cast the feature-grid coords once up front.
        feat_coord = feat_coord.type(feature.type())

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * radius_x + eps_shift
                coord_[:, :, 1] += vy * radius_y + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
                coord_ = coord_.type(feature.type())
                query_feat = F.grid_sample(
                    feature, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                query_coord = F.grid_sample(
                    feat_coord, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                # Offset from the sampled latent's centre, rescaled by the
                # feature-map size.
                rel_coord = coord - query_coord
                rel_coord[:, :, 0] *= feature.shape[-2]
                rel_coord[:, :, 1] *= feature.shape[-1]
                mid_tensor = torch.cat([query_feat, rel_coord], dim=-1)

                if self.cell_decode:
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feature.shape[-2]
                    rel_cell[:, :, 1] *= feature.shape[-1]
                    mid_tensor = torch.cat([mid_tensor, rel_cell], dim=-1)

                bs, q = coord.shape[:2]
                pred = self.imnet(mid_tensor.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)

        total_area = torch.stack(areas).sum(dim=0)
        if self.local_ensemble:
            # Corners were visited as (-1,-1), (-1,1), (1,-1), (1,1);
            # reversing pairs each prediction with the area computed for
            # the diagonally opposite corner.
            areas = areas[::-1]
        result = 0
        for pred, area in zip(preds, areas):
            result = result + pred * (area / total_area).unsqueeze(-1)

        return result

    def batched_predict(self, x, coord, cell):
        """Batched predict.

        Queries ``coord`` (and ``cell``) in chunks of ``self.eval_bsize``
        along dim 1 under ``torch.no_grad()`` and concatenates the results
        along dim 1.

        Args:
            x (Tensor): Encoded feature (output of ``gen_feature``).
            coord (Tensor): coord tensor, shape (B, Q, 2).
            cell (Tensor | None): cell tensor, shape (B, Q, 2), or None
                when cell decoding is disabled.

        Returns:
            pred (Tensor): output of model.

        Raises:
            ValueError: If ``self.eval_bsize`` is not positive (the chunk
                loop would otherwise never advance).
        """
        if self.eval_bsize <= 0:
            raise ValueError('eval_bsize must be a positive integer, '
                             f'got {self.eval_bsize}')
        with torch.no_grad():
            n = coord.shape[1]
            left = 0
            preds = []
            while left < n:
                right = min(left + self.eval_bsize, n)
                # query_rgb accepts cell=None when cell_decode is off.
                cell_chunk = None if cell is None else cell[:, left:right, :]
                pred = self.query_rgb(x, coord[:, left:right, :], cell_chunk)
                preds.append(pred)
                left = right
            pred = torch.cat(preds, dim=1)
        return pred

    @abstractmethod
    def gen_feature(self, x):
        """Generate feature.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
@MODELS.register_module()
class LIIFEDSRNet(LIIFNet):
    """LIIF network whose feature encoder comes from an EDSR model.

    Paper: Learning Continuous Image Representation with
    Local Implicit Image Function

    Once the parent constructor has built ``self.encoder``, its
    ``conv_first``, ``body`` and ``conv_after_body`` sub-modules are
    attached directly to this network and the encoder is discarded.

    Args:
        encoder (dict): Config for the generator.
        imnet (dict): Config for the imnet.
        local_ensemble (bool): Whether to use local ensemble. Default: True.
        feat_unfold (bool): Whether to use feature unfold. Default: True.
        cell_decode (bool): Whether to use cell decode. Default: True.
        eval_bsize (int): Size of batched predict. Default: None.
    """

    def __init__(self,
                 encoder,
                 imnet,
                 local_ensemble=True,
                 feat_unfold=True,
                 cell_decode=True,
                 eval_bsize=None):
        super().__init__(
            encoder=encoder,
            imnet=imnet,
            local_ensemble=local_ensemble,
            feat_unfold=feat_unfold,
            cell_decode=cell_decode,
            eval_bsize=eval_bsize)

        # Re-attach only the encoder parts gen_feature uses, in the same
        # order as before, then drop the full encoder.
        for part_name in ('conv_first', 'body', 'conv_after_body'):
            setattr(self, part_name, getattr(self.encoder, part_name))
        del self.encoder

    def gen_feature(self, x):
        """Encode ``x`` with the EDSR head and body plus a global skip.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: ``conv_after_body(body(s))`` added to ``s``, where
            ``s = conv_first(x)``.
        """
        shallow = self.conv_first(x)
        deep = self.conv_after_body(self.body(shallow))
        deep += shallow
        return deep
@MODELS.register_module()
class LIIFRDNNet(LIIFNet):
    """LIIF network whose feature encoder comes from an RDN model.

    Paper: Learning Continuous Image Representation with
    Local Implicit Image Function

    After the parent constructor builds ``self.encoder``, its ``sfe1``,
    ``sfe2``, ``rdbs``, ``gff`` and ``num_blocks`` are copied onto this
    network and the encoder is removed.

    Args:
        encoder (dict): Config for the generator.
        imnet (dict): Config for the imnet.
        local_ensemble (bool): Whether to use local ensemble. Default: True.
        feat_unfold (bool): Whether to use feat unfold. Default: True.
        cell_decode (bool): Whether to use cell decode. Default: True.
        eval_bsize (int): Size of batched predict. Default: None.
    """

    def __init__(self,
                 encoder,
                 imnet,
                 local_ensemble=True,
                 feat_unfold=True,
                 cell_decode=True,
                 eval_bsize=None):
        super().__init__(
            encoder=encoder,
            imnet=imnet,
            local_ensemble=local_ensemble,
            feat_unfold=feat_unfold,
            cell_decode=cell_decode,
            eval_bsize=eval_bsize)

        built = self.encoder
        self.sfe1 = built.sfe1
        self.sfe2 = built.sfe2
        self.rdbs = built.rdbs
        self.gff = built.gff
        self.num_blocks = built.num_blocks
        del self.encoder

    def gen_feature(self, x):
        """Encode ``x`` with the RDN feature extractor.

        Runs the two shallow convs, then ``num_blocks`` dense blocks in
        sequence. It fuses every block output with ``gff`` and adds the
        first shallow feature back as a global residual.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        first_shallow = self.sfe1(x)
        current = self.sfe2(first_shallow)

        per_block = []
        for block_idx in range(self.num_blocks):
            current = self.rdbs[block_idx](current)
            per_block.append(current)

        fused = self.gff(torch.cat(per_block, 1))
        return fused + first_shallow