Skip to content

Commit

Permalink
Fix the Lambda transform (albumentations-team#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
creafz authored May 18, 2019
1 parent 72b170f commit 822c761
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
4 changes: 4 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,3 +1136,7 @@ def py3round(number):
return int(2.0 * round(number / 2.0))

return int(round(number))


def noop(input_obj, **params):
return input_obj
23 changes: 17 additions & 6 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1899,19 +1899,30 @@ class Lambda(NoOp):
def __init__(self, image=None, mask=None, keypoint=None, bbox=None, always_apply=False, p=1.0):
super(Lambda, self).__init__(always_apply, p)

self._targets = super(Lambda, self).targets

self.custom_apply_fns = {target_name: F.noop for target_name in ('image', 'mask', 'keypoint', 'bbox')}
for target_name, custom_apply_fn in {'image': image, 'mask': mask, 'keypoint': keypoint, 'bbox': bbox}.items():
if custom_apply_fn is not None:
if isinstance(custom_apply_fn, LambdaType):
warnings.warn('Using lambda is incompatible with multiprocessing. '
'Consider using regular functions or partial().')

self._targets[target_name] = custom_apply_fn
self.custom_apply_fns[target_name] = custom_apply_fn

@property
def targets(self):
return self._targets
def apply(self, img, **params):
fn = self.custom_apply_fns['image']
return fn(img, **params)

def apply_to_mask(self, mask, **params):
fn = self.custom_apply_fns['mask']
return fn(mask, **params)

def apply_to_bbox(self, bbox, **params):
fn = self.custom_apply_fns['bbox']
return fn(bbox, **params)

def apply_to_keypoint(self, keypoint, **params):
fn = self.custom_apply_fns['keypoint']
return fn(keypoint, **params)

def to_dict(self):
raise NotImplementedError('Lambda is not serializable')
25 changes: 22 additions & 3 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,27 @@ def one_hot_mask(mask, num_channels, **kwargs):
new_mask = np.eye(num_channels, dtype=np.uint8)[mask]
return new_mask

aug = A.Lambda(image=negate_image, mask=partial(one_hot_mask, num_channels=16), p=1)

output = aug(image=np.ones((10, 10, 3), dtype=np.float32), mask=np.tile(np.arange(0, 10), (10, 1)))
def vflip_bbox(bbox, **kwargs):
return F.bbox_vflip(bbox, **kwargs)

def vflip_keypoint(keypoint, **kwargs):
return F.keypoint_vflip(keypoint, **kwargs)

aug = A.Lambda(
image=negate_image,
mask=partial(one_hot_mask, num_channels=16),
bbox=vflip_bbox,
keypoint=vflip_keypoint,
p=1,
)

output = aug(
image=np.ones((10, 10, 3), dtype=np.float32),
mask=np.tile(np.arange(0, 10), (10, 1)),
bboxes=[[10, 15, 25, 35]],
keypoints=[[20, 30, 40, 50]],
)
assert (output['image'] < 0).all()
assert output['mask'].shape[2] == 16 # num_channels
assert output['bboxes'] == [F.bbox_vflip([10, 15, 25, 35], 10, 10)]
assert output['keypoints'] == [F.keypoint_vflip([20, 30, 40, 50], 10, 10)]

0 comments on commit 822c761

Please sign in to comment.