注意
您正在阅读 MMEditing 0.x。MMEditing 0.x 将于 2022 年末开始逐步停止维护，建议您及时升级到 MMEditing 1.0 版本，以享受 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。请阅读 MMEditing 1.0 的发版日志、代码和文档以了解更多。
mmedit.models.losses.gradient_loss 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import LOSSES
from .pixelwise_loss import l1_loss
_reduction_modes = ['none', 'mean', 'sum']
@LOSSES.register_module()
class GradientLoss(nn.Module):
    """Gradient loss.

    Filters both the prediction and the target with the horizontal and
    vertical 3x3 Sobel kernels. It then sums the L1 losses between the
    corresponding gradient maps, so mismatched edges are penalised rather
    than mismatched intensities.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction
        # Validate against the module-level list so the check and the error
        # message can never disagree.
        if self.reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {self.reduction}. '
                             f'Supported ones are: {_reduction_modes}')

    def forward(self, pred, target, weight=None):
        """Forward function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.

        Returns:
            Tensor: ``loss_weight`` times the sum of the L1 losses on the
                x-gradients and on the y-gradients, each reduced according
                to ``self.reduction``.
        """
        channels = target.size(1)
        # 3x3 Sobel kernels, cast to the dtype/device of ``target``.
        kx = torch.Tensor([[1, 0, -1], [2, 0, -2],
                           [1, 0, -1]]).view(1, 1, 3, 3).to(target)
        ky = torch.Tensor([[1, 2, 1], [0, 0, 0],
                           [-1, -2, -1]]).view(1, 1, 3, 3).to(target)
        # One kernel copy per channel plus ``groups=channels`` filters each
        # channel independently (depthwise). A bare (1, 1, 3, 3) kernel only
        # accepts single-channel input; for C == 1 this is unchanged.
        kx = kx.repeat(channels, 1, 1, 1)
        ky = ky.repeat(channels, 1, 1, 1)

        pred_grad_x = F.conv2d(pred, kx, padding=1, groups=channels)
        pred_grad_y = F.conv2d(pred, ky, padding=1, groups=channels)
        target_grad_x = F.conv2d(target, kx, padding=1, groups=channels)
        target_grad_y = F.conv2d(target, ky, padding=1, groups=channels)

        loss = (
            l1_loss(
                pred_grad_x, target_grad_x, weight, reduction=self.reduction) +
            l1_loss(
                pred_grad_y, target_grad_y, weight, reduction=self.reduction))
        return loss * self.loss_weight