注意
您正在阅读 MMEditing 0.x。 MMEditing 0.x 会在 2022 年末开始逐步停止维护,建议您及时升级到 MMEditing 1.0 版本,享受由 OpenMMLab 2.0 带来的更多新特性和更佳的性能表现。阅读 MMEditing 1.0 的发版日志、 代码 和 文档 以了解更多。
mmedit.datasets.pipelines.matting_aug 源代码
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import random
import cv2
import mmcv
import numpy as np
from mmcv.fileio import FileClient
from ..registry import PIPELINES
from .utils import adjust_gamma, random_choose_unknown
def add_gaussian_noise(img, mu, sigma):
img = img.astype(np.float32)
gauss_noise = np.random.normal(mu, sigma, img.shape)
noisy_img = img + gauss_noise
noisy_img = np.clip(noisy_img, 0, 255)
return noisy_img
[文档]@PIPELINES.register_module()
class MergeFgAndBg:
"""Composite foreground image and background image with alpha.
Required keys are "alpha", "fg" and "bg", added key is "merged".
"""
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
alpha = results['alpha'][..., None].astype(np.float32) / 255.
fg = results['fg']
bg = results['bg']
merged = fg * alpha + (1. - alpha) * bg
results['merged'] = merged
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str
[文档]@PIPELINES.register_module()
class GenerateTrimap:
"""Using random erode/dilate to generate trimap from alpha matte.
Required key is "alpha", added key is "trimap".
Args:
kernel_size (int | tuple[int]): The range of random kernel_size of
erode/dilate; int indicates a fixed kernel_size. If `random` is set
to False and kernel_size is a tuple of length 2, then it will be
interpreted as (erode kernel_size, dilate kernel_size). It should
be noted that the kernel of the erosion and dilation has the same
height and width.
iterations (int | tuple[int], optional): The range of random iterations
of erode/dilate; int indicates a fixed iterations. If `random` is
set to False and iterations is a tuple of length 2, then it will be
interpreted as (erode iterations, dilate iterations). Default to 1.
random (bool, optional): Whether use random kernel_size and iterations
when generating trimap. See `kernel_size` and `iterations` for more
information.
"""
def __init__(self, kernel_size, iterations=1, random=True):
if isinstance(kernel_size, int):
kernel_size = kernel_size, kernel_size + 1
elif not mmcv.is_tuple_of(kernel_size, int) or len(kernel_size) != 2:
raise ValueError('kernel_size must be an int or a tuple of 2 int, '
f'but got {kernel_size}')
if isinstance(iterations, int):
iterations = iterations, iterations + 1
elif not mmcv.is_tuple_of(iterations, int) or len(iterations) != 2:
raise ValueError('iterations must be an int or a tuple of 2 int, '
f'but got {iterations}')
self.random = random
if self.random:
min_kernel, max_kernel = kernel_size
self.iterations = iterations
self.kernels = [
cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
for size in range(min_kernel, max_kernel)
]
else:
erode_ksize, dilate_ksize = kernel_size
self.iterations = iterations
self.kernels = [
cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(erode_ksize, erode_ksize)),
cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(dilate_ksize, dilate_ksize))
]
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
alpha = results['alpha']
if self.random:
kernel_num = len(self.kernels)
erode_kernel_idx = np.random.randint(kernel_num)
dilate_kernel_idx = np.random.randint(kernel_num)
min_iter, max_iter = self.iterations
erode_iter = np.random.randint(min_iter, max_iter)
dilate_iter = np.random.randint(min_iter, max_iter)
else:
erode_kernel_idx, dilate_kernel_idx = 0, 1
erode_iter, dilate_iter = self.iterations
eroded = cv2.erode(
alpha, self.kernels[erode_kernel_idx], iterations=erode_iter)
dilated = cv2.dilate(
alpha, self.kernels[dilate_kernel_idx], iterations=dilate_iter)
trimap = np.zeros_like(alpha)
trimap.fill(128)
trimap[eroded >= 255] = 255
trimap[dilated <= 0] = 0
results['trimap'] = trimap.astype(np.float32)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(kernels={self.kernels}, iterations={self.iterations}, '
f'random={self.random})')
return repr_str
[文档]@PIPELINES.register_module()
class GenerateTrimapWithDistTransform:
"""Generate trimap with distance transform function.
Args:
dist_thr (int, optional): Distance threshold. Area with alpha value
between (0, 255) will be considered as initial unknown area. Then
area with distance to unknown area smaller than the distance
threshold will also be consider as unknown area. Defaults to 20.
random (bool, optional): If True, use random distance threshold from
[1, dist_thr). If False, use `dist_thr` as the distance threshold
directly. Defaults to True.
"""
def __init__(self, dist_thr=20, random=True):
if not (isinstance(dist_thr, int) and dist_thr >= 1):
raise ValueError('dist_thr must be an int that is greater than 1, '
f'but got {dist_thr}')
self.dist_thr = dist_thr
self.random = random
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
alpha = results['alpha']
# image dilation implemented by Euclidean distance transform
known = (alpha == 0) | (alpha == 255)
dist_to_unknown = cv2.distanceTransform(
known.astype(np.uint8), cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
dist_thr = np.random.randint(
1, self.dist_thr) if self.random else self.dist_thr
unknown = dist_to_unknown <= dist_thr
trimap = (alpha == 255) * 255
trimap[unknown] = 128
results['trimap'] = trimap.astype(np.uint8)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(dist_thr={self.dist_thr}, random={self.random})'
return repr_str
[文档]@PIPELINES.register_module()
class CompositeFg:
"""Composite foreground with a random foreground.
This class composites the current training sample with additional data
randomly (could be from the same dataset). With probability 0.5, the sample
will be composited with a random sample from the specified directory.
The composition is performed as:
.. math::
fg_{new} = \\alpha_1 * fg_1 + (1 - \\alpha_1) * fg_2
\\alpha_{new} = 1 - (1 - \\alpha_1) * (1 - \\alpha_2)
where :math:`(fg_1, \\alpha_1)` is from the current sample and
:math:`(fg_2, \\alpha_2)` is the randomly loaded sample. With the above
composition, :math:`\\alpha_{new}` is still in `[0, 1]`.
Required keys are "alpha" and "fg". Modified keys are "alpha" and "fg".
Args:
fg_dirs (str | list[str]): Path of directories to load foreground
images from.
alpha_dirs (str | list[str]): Path of directories to load alpha mattes
from.
interpolation (str): Interpolation method of `mmcv.imresize` to resize
the randomly loaded images.
"""
def __init__(self,
fg_dirs,
alpha_dirs,
interpolation='nearest',
io_backend='disk',
**kwargs):
self.fg_dirs = fg_dirs if isinstance(fg_dirs, list) else [fg_dirs]
self.alpha_dirs = alpha_dirs if isinstance(alpha_dirs,
list) else [alpha_dirs]
self.interpolation = interpolation
self.fg_list, self.alpha_list = self._get_file_list(
self.fg_dirs, self.alpha_dirs)
self.io_backend = io_backend
self.file_client = None
self.kwargs = kwargs
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
if self.file_client is None:
self.file_client = FileClient(self.io_backend, **self.kwargs)
fg = results['fg']
alpha = results['alpha'].astype(np.float32) / 255.
h, w = results['fg'].shape[:2]
# randomly select fg
if np.random.rand() < 0.5:
idx = np.random.randint(len(self.fg_list))
fg2_bytes = self.file_client.get(self.fg_list[idx])
fg2 = mmcv.imfrombytes(fg2_bytes)
alpha2_bytes = self.file_client.get(self.alpha_list[idx])
alpha2 = mmcv.imfrombytes(alpha2_bytes, flag='grayscale')
alpha2 = alpha2.astype(np.float32) / 255.
fg2 = mmcv.imresize(fg2, (w, h), interpolation=self.interpolation)
alpha2 = mmcv.imresize(
alpha2, (w, h), interpolation=self.interpolation)
# the overlap of two 50% transparency will be 75%
alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
# if the result alpha is all-one, then we avoid composition
if np.any(alpha_tmp < 1):
# composite fg with fg2
fg = fg.astype(np.float32) * alpha[..., None] \
+ fg2.astype(np.float32) * (1 - alpha[..., None])
alpha = alpha_tmp
fg.astype(np.uint8)
results['fg'] = fg
results['alpha'] = (alpha * 255).astype(np.uint8)
return results
@staticmethod
def _get_file_list(fg_dirs, alpha_dirs):
all_fg_list = list()
all_alpha_list = list()
for fg_dir, alpha_dir in zip(fg_dirs, alpha_dirs):
fg_list = sorted(mmcv.scandir(fg_dir))
alpha_list = sorted(mmcv.scandir(alpha_dir))
# we assume the file names for fg and alpha are the same
assert len(fg_list) == len(alpha_list), (
f'{fg_dir} and {alpha_dir} should have the same number of '
f'images ({len(fg_list)} differs from ({len(alpha_list)})')
fg_list = [osp.join(fg_dir, fg) for fg in fg_list]
alpha_list = [osp.join(alpha_dir, alpha) for alpha in alpha_list]
all_fg_list.extend(fg_list)
all_alpha_list.extend(alpha_list)
return all_fg_list, all_alpha_list
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(fg_dirs={self.fg_dirs}, alpha_dirs={self.alpha_dirs}, '
f"interpolation='{self.interpolation}')")
return repr_str
[文档]@PIPELINES.register_module()
class GenerateSeg:
"""Generate segmentation mask from alpha matte.
Args:
kernel_size (int, optional): Kernel size for both erosion and
dilation. The kernel will have the same height and width.
Defaults to 5.
erode_iter_range (tuple, optional): Iteration of erosion.
Defaults to (10, 20).
dilate_iter_range (tuple, optional): Iteration of dilation.
Defaults to (15, 30).
num_holes_range (tuple, optional): Range of number of holes to
randomly select from. Defaults to (0, 3).
hole_sizes (list, optional): List of (h, w) to be selected as the
size of the rectangle hole.
Defaults to [(15, 15), (25, 25), (35, 35), (45, 45)].
blur_ksizes (list, optional): List of (h, w) to be selected as the
kernel_size of the gaussian blur.
Defaults to [(21, 21), (31, 31), (41, 41)].
"""
def __init__(self,
kernel_size=5,
erode_iter_range=(10, 20),
dilate_iter_range=(15, 30),
num_holes_range=(0, 3),
hole_sizes=[(15, 15), (25, 25), (35, 35), (45, 45)],
blur_ksizes=[(21, 21), (31, 31), (41, 41)]):
self.kernel_size = kernel_size
self.erode_iter_range = erode_iter_range
self.dilate_iter_range = dilate_iter_range
self.num_holes_range = num_holes_range
self.hole_sizes = hole_sizes
self.blur_ksizes = blur_ksizes
@staticmethod
def _crop_hole(img, start_point, hole_size):
"""Create a all-zero rectangle hole in the image.
Args:
img (np.ndarray): Source image.
start_point (tuple[int]): The top-left point of the rectangle.
hole_size (tuple[int]): The height and width of the rectangle hole.
Return:
np.ndarray: The cropped image.
"""
top, left = start_point
bottom = top + hole_size[0]
right = left + hole_size[1]
height, weight = img.shape[:2]
if top < 0 or bottom > height or left < 0 or right > weight:
raise ValueError(f'crop area {(left, top, right, bottom)} exceeds '
f'image size {(height, weight)}')
img[top:bottom, left:right] = 0
return img
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
alpha = results['alpha']
trimap = results['trimap']
# generete segmentation mask
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(self.kernel_size,
self.kernel_size))
seg = (alpha > 0.5).astype(np.float32)
seg = cv2.erode(
seg, kernel, iterations=np.random.randint(*self.erode_iter_range))
seg = cv2.dilate(
seg, kernel, iterations=np.random.randint(*self.dilate_iter_range))
# generate some holes in segmentation mask
num_holes = np.random.randint(*self.num_holes_range)
for _ in range(num_holes):
hole_size = random.choice(self.hole_sizes)
unknown = trimap == 128
start_point = random_choose_unknown(unknown, hole_size)
seg = self._crop_hole(seg, start_point, hole_size)
trimap = self._crop_hole(trimap, start_point, hole_size)
# perform gaussian blur to segmentation mask
seg = cv2.GaussianBlur(seg, random.choice(self.blur_ksizes), 0)
results['seg'] = seg.astype(np.uint8)
results['num_holes'] = num_holes
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (
f'(kernel_size={self.kernel_size}, '
f'erode_iter_range={self.erode_iter_range}, '
f'dilate_iter_range={self.dilate_iter_range}, '
f'num_holes_range={self.num_holes_range}, '
f'hole_sizes={self.hole_sizes}, blur_ksizes={self.blur_ksizes}')
return repr_str
[文档]@PIPELINES.register_module()
class PerturbBg:
"""Randomly add gaussian noise or gamma change to background image.
Required key is "bg", added key is "noisy_bg".
Args:
gamma_ratio (float, optional): The probability to use gamma correction
instead of gaussian noise. Defaults to 0.6.
"""
def __init__(self, gamma_ratio=0.6):
if gamma_ratio < 0 or gamma_ratio > 1:
raise ValueError('gamma_ratio must be a float between [0, 1], '
f'but got {gamma_ratio}')
self.gamma_ratio = gamma_ratio
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
if np.random.rand() >= self.gamma_ratio:
# generate gaussian noise with random gaussian N([-7, 7), [2, 6))
mu = np.random.randint(-7, 7)
sigma = np.random.randint(2, 6)
results['noisy_bg'] = add_gaussian_noise(results['bg'], mu, sigma)
else:
# adjust gamma in a range of N(1, 0.12)
gamma = np.random.normal(1, 0.12)
results['noisy_bg'] = adjust_gamma(results['bg'], gamma)
return results
def __repr__(self):
return self.__class__.__name__ + f'(gamma_ratio={self.gamma_ratio})'
[文档]@PIPELINES.register_module()
class GenerateSoftSeg:
"""Generate soft segmentation mask from input segmentation mask.
Required key is "seg", added key is "soft_seg".
Args:
fg_thr (float, optional): Threshold of the foreground in the normalized
input segmentation mask. Defaults to 0.2.
border_width (int, optional): Width of border to be padded to the
bottom of the mask. Defaults to 25.
erode_ksize (int, optional): Fixed kernel size of the erosion.
Defaults to 5.
dilate_ksize (int, optional): Fixed kernel size of the dilation.
Defaults to 5.
erode_iter_range (tuple, optional): Iteration of erosion.
Defaults to (10, 20).
dilate_iter_range (tuple, optional): Iteration of dilation.
Defaults to (3, 7).
blur_ksizes (list, optional): List of (h, w) to be selected as the
kernel_size of the gaussian blur.
Defaults to [(21, 21), (31, 31), (41, 41)].
"""
def __init__(self,
fg_thr=0.2,
border_width=25,
erode_ksize=3,
dilate_ksize=5,
erode_iter_range=(10, 20),
dilate_iter_range=(3, 7),
blur_ksizes=[(21, 21), (31, 31), (41, 41)]):
if not isinstance(fg_thr, float):
raise TypeError(f'fg_thr must be a float, but got {type(fg_thr)}')
if not isinstance(border_width, int):
raise TypeError(
f'border_width must be an int, but got {type(border_width)}')
if not isinstance(erode_ksize, int):
raise TypeError(
f'erode_ksize must be an int, but got {type(erode_ksize)}')
if not isinstance(dilate_ksize, int):
raise TypeError(
f'dilate_ksize must be an int, but got {type(dilate_ksize)}')
if (not mmcv.is_tuple_of(erode_iter_range, int)
or len(erode_iter_range) != 2):
raise TypeError('erode_iter_range must be a tuple of 2 int, '
f'but got {erode_iter_range}')
if (not mmcv.is_tuple_of(dilate_iter_range, int)
or len(dilate_iter_range) != 2):
raise TypeError('dilate_iter_range must be a tuple of 2 int, '
f'but got {dilate_iter_range}')
if not mmcv.is_list_of(blur_ksizes, tuple):
raise TypeError(
f'blur_ksizes must be a list of tuple, but got {blur_ksizes}')
self.fg_thr = fg_thr
self.border_width = border_width
self.erode_ksize = erode_ksize
self.dilate_ksize = dilate_ksize
self.erode_iter_range = erode_iter_range
self.dilate_iter_range = dilate_iter_range
self.blur_ksizes = blur_ksizes
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
seg = results['seg'].astype(np.float32) / 255
height, _ = seg.shape[:2]
seg[seg > self.fg_thr] = 1
# to align with the original repo, pad the bottom of the mask
seg = cv2.copyMakeBorder(seg, 0, self.border_width, 0, 0,
cv2.BORDER_REPLICATE)
# erode/dilate segmentation mask
erode_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(self.erode_ksize,
self.erode_ksize))
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(self.dilate_ksize,
self.dilate_ksize))
seg = cv2.erode(
seg,
erode_kernel,
iterations=np.random.randint(*self.erode_iter_range))
seg = cv2.dilate(
seg,
dilate_kernel,
iterations=np.random.randint(*self.dilate_iter_range))
# perform gaussian blur to segmentation mask
seg = cv2.GaussianBlur(seg, random.choice(self.blur_ksizes), 0)
# remove the padded rows
seg = (seg * 255).astype(np.uint8)
seg = np.delete(seg, range(height, height + self.border_width), 0)
results['soft_seg'] = seg
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(fg_thr={self.fg_thr}, '
f'border_width={self.border_width}, '
f'erode_ksize={self.erode_ksize}, '
f'dilate_ksize={self.dilate_ksize}, '
f'erode_iter_range={self.erode_iter_range}, '
f'dilate_iter_range={self.dilate_iter_range}, '
f'blur_ksizes={self.blur_ksizes})')
return repr_str
[文档]@PIPELINES.register_module()
class TransformTrimap:
"""Transform trimap into two-channel and six-channel.
This class will generate a two-channel trimap composed of definite
foreground and background masks and encode it into a six-channel trimap
using Gaussian blurs of the generated two-channel trimap at three
different scales. The transformed trimap has 6 channels.
Required key is "trimap", added key is "transformed_trimap" and
"two_channel_trimap".
Adopted from the following repository:
https://github.com/MarcoForte/FBA_Matting/blob/master/networks/transforms.py.
"""
def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
trimap = results['trimap']
assert len(trimap.shape) == 2
h, w = trimap.shape[:2]
# generate two-channel trimap
trimap2 = np.zeros((h, w, 2), dtype=np.uint8)
trimap2[trimap == 0, 0] = 255
trimap2[trimap == 255, 1] = 255
trimap_trans = np.zeros((h, w, 6), dtype=np.float32)
factor = np.array([[[0.02, 0.08, 0.16]]], dtype=np.float32)
for k in range(2):
if np.any(trimap2[:, :, k]):
dt_mask = -cv2.distanceTransform(255 - trimap2[:, :, k],
cv2.DIST_L2, 0)**2
dt_mask = dt_mask[..., None]
L = 320
trimap_trans[..., 3 * k:3 * k +
3] = np.exp(dt_mask / (2 * ((factor * L)**2)))
results['transformed_trimap'] = trimap_trans
results['two_channel_trimap'] = trimap2
return results
def __repr__(self):
repr_str = self.__class__.__name__
return repr_str