diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index 4ceb5a94c..43ffd8db9 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -1136,3 +1136,7 @@ def py3round(number): return int(2.0 * round(number / 2.0)) return int(round(number)) + + +def noop(input_obj, **params): + return input_obj diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 84b07244d..a3ec2d601 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -1899,19 +1899,30 @@ class Lambda(NoOp): def __init__(self, image=None, mask=None, keypoint=None, bbox=None, always_apply=False, p=1.0): super(Lambda, self).__init__(always_apply, p) - self._targets = super(Lambda, self).targets - + self.custom_apply_fns = {target_name: F.noop for target_name in ('image', 'mask', 'keypoint', 'bbox')} for target_name, custom_apply_fn in {'image': image, 'mask': mask, 'keypoint': keypoint, 'bbox': bbox}.items(): if custom_apply_fn is not None: if isinstance(custom_apply_fn, LambdaType): warnings.warn('Using lambda is incompatible with multiprocessing. ' 'Consider using regular functions or partial().') - self._targets[target_name] = custom_apply_fn + self.custom_apply_fns[target_name] = custom_apply_fn - @property - def targets(self): - return self._targets + def apply(self, img, **params): + fn = self.custom_apply_fns['image'] + return fn(img, **params) + + def apply_to_mask(self, mask, **params): + fn = self.custom_apply_fns['mask'] + return fn(mask, **params) + + def apply_to_bbox(self, bbox, **params): + fn = self.custom_apply_fns['bbox'] + return fn(bbox, **params) + + def apply_to_keypoint(self, keypoint, **params): + fn = self.custom_apply_fns['keypoint'] + return fn(keypoint, **params) def to_dict(self): raise NotImplementedError('Lambda is not serializable') diff --git a/tests/test_transforms.py b/tests/test_transforms.py index a77fca0e9..423b2fdcb 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -249,8 +249,27 @@ def one_hot_mask(mask, num_channels, **kwargs): new_mask = np.eye(num_channels, dtype=np.uint8)[mask] return new_mask - aug = A.Lambda(image=negate_image, mask=partial(one_hot_mask, num_channels=16), p=1) - - output = aug(image=np.ones((10, 10, 3), dtype=np.float32), mask=np.tile(np.arange(0, 10), (10, 1))) + def vflip_bbox(bbox, **kwargs): + return F.bbox_vflip(bbox, **kwargs) + + def vflip_keypoint(keypoint, **kwargs): + return F.keypoint_vflip(keypoint, **kwargs) + + aug = A.Lambda( + image=negate_image, + mask=partial(one_hot_mask, num_channels=16), + bbox=vflip_bbox, + keypoint=vflip_keypoint, + p=1, + ) + + output = aug( + image=np.ones((10, 10, 3), dtype=np.float32), + mask=np.tile(np.arange(0, 10), (10, 1)), + bboxes=[[10, 15, 25, 35]], + keypoints=[[20, 30, 40, 50]], + ) assert (output['image'] < 0).all() assert output['mask'].shape[2] == 16 # num_channels + assert output['bboxes'] == [F.bbox_vflip([10, 15, 25, 35], 10, 10)] + assert output['keypoints'] == [F.keypoint_vflip([20, 30, 40, 50], 10, 10)]