Shortcuts

mmedit.models.utils.tensor_utils 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch


def get_unknown_tensor(trimap, unknown_value=128 / 255):
    """Extract a single-channel mask of the unknown region from a trimap.

    Args:
        trimap (Tensor): Trimap of shape (N, 3, H, W) or (N, 1, H, W).
        unknown_value (float): Scalar that marks the unknown region in a
            non-one-hot trimap. Typical choices, depending on how the
            trimap was pre-processed:

            - ``'rescale_to_zero_one'``: bg is 0, unknown is 128/255 and
              fg is 1, so pass ``128 / 255``.
            - :meth:`FormatTrimap(to_onehot=False)`: bg is 0, unknown is 1
              and fg is 2, so pass ``1``.
            - :meth:`FormatTrimap(to_onehot=True)`: the trimap has three
              channels and this argument is ignored.

    Returns:
        Tensor: Float mask of the unknown area, shape (N, 1, H, W).
    """
    num_channels = trimap.shape[1]

    if num_channels != 3:
        # Scalar-coded trimap: a pixel is unknown exactly where it equals
        # ``unknown_value``.
        unknown_mask = trimap == unknown_value
        return unknown_mask.float()

    # One-hot trimap with channels ordered (bg, unknown, fg): slice out the
    # middle channel while keeping the channel axis.
    return trimap[:, 1:2, :, :].float()
def normalize_vecs(vectors: torch.Tensor,
                   eps: float = 1e-12) -> torch.Tensor:
    """Normalize vectors by their L2 length along the last dimension.

    Each vector along the last axis is divided by its own norm. The
    divisor is clamped from below by ``eps`` so that an all-zero vector
    yields zeros instead of NaN from a 0 / 0 division.

    Args:
        vectors (torch.Tensor): Vectors to be normalized. The last
            dimension is treated as the vector dimension.
        eps (float): Lower bound applied to the norm before dividing.
            Vectors whose norm is at least ``eps`` are unaffected.
            Defaults to 1e-12.

    Returns:
        torch.Tensor: Vectors after normalization, same shape as the input.
    """
    lengths = torch.norm(vectors, dim=-1, keepdim=True)
    # Guard against division by zero for degenerate (zero-length) vectors.
    return vectors / lengths.clamp_min(eps)