Skip to content

Commit

Permalink
Fixing keypoints augmentation bug (albumentations-team#188)
Browse files Browse the repository at this point in the history
* Fixing keypoints augmentation bug

* Fixed unit tests

* Fixed unit tests
  • Loading branch information
BloodAxe authored and ternaus committed Feb 14, 2019
1 parent 4c61d47 commit 97eb72d
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 24 deletions.
46 changes: 24 additions & 22 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,30 @@ def shift_scale_rotate(img, angle, scale, dx, dy, interpolation=cv2.INTER_LINEAR
return img


def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, interpolation, rows, cols, **params):
center = (0.5, 0.5)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
matrix[0, 2] += dx
matrix[1, 2] += dy
x = np.array([bbox[0], bbox[2], bbox[2], bbox[0]])
y = np.array([bbox[1], bbox[1], bbox[3], bbox[3]])
ones = np.ones(shape=(len(x)))
points_ones = np.vstack([x, y, ones]).transpose()
tr_points = matrix.dot(points_ones.T).T
return [min(tr_points[:, 0]), min(tr_points[:, 1]), max(tr_points[:, 0]), max(tr_points[:, 1])]


def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **params):
height, width = rows, cols
center = (width / 2, height / 2)
x, y, a, s = keypoint
matrix = cv2.getRotationMatrix2D(center, angle, scale)
matrix[0, 2] += dx * width
matrix[1, 2] += dy * height
x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()
return [x, y, a + math.radians(angle), s * scale]


def crop(img, x_min, y_min, x_max, y_max):
height, width = img.shape[:2]
if x_max <= x_min or y_max <= y_min:
Expand Down Expand Up @@ -629,19 +653,6 @@ def from_float(img, dtype, max_value=None):
return (img * max_value).astype(dtype)


def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, interpolation, rows, cols, **params):
center = (0.5, 0.5)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
matrix[0, 2] += dx
matrix[1, 2] += dy
x = np.array([bbox[0], bbox[2], bbox[2], bbox[0]])
y = np.array([bbox[1], bbox[1], bbox[3], bbox[3]])
ones = np.ones(shape=(len(x)))
points_ones = np.vstack([x, y, ones]).transpose()
tr_points = matrix.dot(points_ones.T).T
return [min(tr_points[:, 0]), min(tr_points[:, 1]), max(tr_points[:, 0]), max(tr_points[:, 1])]


def bbox_vflip(bbox, rows, cols):
"""Flip a bounding box vertically around the x-axis."""
x_min, y_min, x_max, y_max = bbox
Expand Down Expand Up @@ -853,12 +864,3 @@ def keypoint_random_crop(keypoint, crop_height, crop_width, h_start, w_start, ro
def keypoint_center_crop(bbox, crop_height, crop_width, rows, cols):
crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width)
return crop_keypoint_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)


def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **params):
x, y, a, s = keypoint
matrix = cv2.getRotationMatrix2D(((cols - 1) * 0.5, (rows - 1) * 0.5), angle, scale)
matrix[0, 2] += dx
matrix[1, 2] += dy
x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()
return [x, y, a + math.radians(angle), s * scale]
54 changes: 54 additions & 0 deletions notebooks/example_bbox_keypoint_rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt


def visualize(image, keypoints, bboxes):
overlay = image.copy()
for kp in keypoints:
cv2.circle(overlay, (int(kp[0]), int(kp[1])), 20, (0, 200, 200),
thickness=2,
lineType=cv2.LINE_AA)

for box in bboxes:
cv2.rectangle(overlay, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (200, 0, 0),
thickness=2)

return overlay


def main():
image = cv2.imread('images/image_1.jpg')

keypoints = cv2.goodFeaturesToTrack(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY),
maxCorners=100,
qualityLevel=.5,
minDistance=5).squeeze(1)

bboxes = [(kp[0] - 10, kp[1] - 10, kp[0] + 10, kp[1] + 10) for kp in keypoints]

disp_image = visualize(image, keypoints, bboxes)
plt.figure(figsize=(10, 10))
plt.imshow(cv2.cvtColor(disp_image, cv2.COLOR_RGB2BGR))
plt.tight_layout()
plt.show()

aug = A.Compose([
A.ShiftScaleRotate(scale_limit=0.1, shift_limit=0.2, rotate_limit=10, always_apply=True)
], bbox_params={'format': 'pascal_voc', 'label_fields': ['bbox_labels']}, keypoint_params={'format': 'xy'})

for i in range(10):
data = aug(image=image, keypoints=keypoints, bboxes=bboxes, bbox_labels=np.ones(len(bboxes)))

aug_image = data['image']
aug_image = visualize(aug_image, data['keypoints'], data['bboxes'])

plt.figure(figsize=(10, 10))
plt.imshow(cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR))
plt.tight_layout()
plt.show()


if __name__ == '__main__':
main()
4 changes: 2 additions & 2 deletions tests/test_keypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def test_keypoint_scale(keypoint, expected, scale):


@pytest.mark.parametrize(['keypoint', 'expected', 'angle', 'scale', 'dx', 'dy'], [
[[50, 50, 0, 5], [110, 158, math.pi / 2, 10], 90, 2, 10, 10],
[[50, 50, 0, 5], [120, 160, math.pi / 2, 10], 90, 2, 0.1, 0.1],
])
def test_keypoint_shift_scale_rotate(keypoint, expected, angle, scale, dx, dy):
actual = F.keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows=100, cols=200)
np.testing.assert_allclose(actual, expected, atol=1e-7)
np.testing.assert_allclose(actual, expected, rtol=1e-4)

0 comments on commit 97eb72d

Please sign in to comment.