Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper decorator to preserve shape of input image #106

Merged
merged 6 commits into from
Oct 22, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add second decorator to preserve only channels dim
  • Loading branch information
BloodAxe committed Oct 19, 2018
commit 719bab89f82f199ca61af2d40e03c7cc2e4552af
31 changes: 24 additions & 7 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ def wrapped_function(img, *args, **kwargs):
return wrapped_function


def preserve_channel_dim(func):
    """
    Decorator that puts back a trailing channel axis dropped by the wrapped function.

    OpenCV functions tend to squeeze the last dimension when its size is 1.
    If the input was 3-D with a last axis of size 1 and the wrapped function
    returned a 2-D array, a new last axis is appended to the result.
    In every other case the result is returned unchanged.
    """

    @wraps(func)
    def wrapped_function(img, *args, **kwargs):
        # Remember the input shape before the wrapped call can alter anything.
        input_shape = img.shape
        output = func(img, *args, **kwargs)

        had_single_channel = len(input_shape) == 3 and input_shape[-1] == 1
        if had_single_channel and len(output.shape) == 2:
            return np.expand_dims(output, axis=-1)
        return output

    return wrapped_function


def vflip(img):
    """
    Return a contiguous copy of ``img`` with its first axis reversed.
    """
    # The reversed slice is a view; ascontiguousarray materialises it.
    flipped_view = img[::-1, ...]
    return np.ascontiguousarray(flipped_view)

Expand Down Expand Up @@ -96,29 +113,29 @@ def cutout(img, num_holes, max_h_size, max_w_size):
return img


@preserve_channel_dim
def rotate(img, angle, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101):
    """
    Rotate ``img`` about its centre, producing an output of the same width and height.

    Args:
        img: array whose first two axes are read as (height, width).
        angle: rotation angle handed to ``cv2.getRotationMatrix2D``.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.warpAffine``.
        border_mode: OpenCV border mode forwarded to ``cv2.warpAffine``.

    Returns:
        The image returned by ``cv2.warpAffine``.
    """
    height, width = img.shape[:2]
    # Rotation about (width / 2, height / 2) with a scale factor of 1.0.
    matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
    # The destination size is passed as (width, height).
    img = cv2.warpAffine(img, matrix, (width, height), flags=interpolation, borderMode=border_mode)
    return img


@preserve_channel_dim
def scale(img, scale, interpolation=cv2.INTER_LINEAR):
    """
    Resize ``img`` by the multiplicative factor ``scale``.

    Args:
        img: array whose first two axes are read as (height, width).
        scale: factor applied to both height and width.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The image returned by ``cv2.resize``.
    """
    height, width = img.shape[:2]
    # int() truncates the scaled dimensions toward zero.
    new_height, new_width = int(height * scale), int(width * scale)
    # The target size is passed as (width, height).
    img = cv2.resize(img, (new_width, new_height), interpolation=interpolation)
    return img


@preserve_channel_dim
def resize(img, height, width, interpolation=cv2.INTER_LINEAR):
    """
    Resize ``img`` to the requested ``height`` and ``width``.

    Args:
        img: image array passed to ``cv2.resize``.
        height: target height.
        width: target width.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The image returned by ``cv2.resize``.
    """
    # The target size is passed as (width, height), the reverse of this
    # function's own (height, width) parameter order.
    img = cv2.resize(img, (width, height), interpolation=interpolation)
    return img


@preserve_shape
@preserve_channel_dim
def shift_scale_rotate(img, angle, scale, dx, dy, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101):
height, width = img.shape[:2]
center = (width / 2, height / 2)
Expand Down Expand Up @@ -250,7 +267,7 @@ def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)):
return img


@preserve_shape
@preserve_channel_dim
def pad(img, min_height, min_width, border_mode=cv2.BORDER_REFLECT_101, value=[0, 0, 0]):
height, width = img.shape[:2]

Expand Down Expand Up @@ -296,12 +313,12 @@ def _func_max_size(img, max_size, interpolation, func):
return img


@preserve_channel_dim
def longest_max_size(img, max_size, interpolation):
    """
    Apply ``_func_max_size`` to ``img`` using the built-in ``max`` as the selector.

    NOTE(review): presumably this rescales so the longest side equals
    ``max_size``; confirm against ``_func_max_size``.

    Args:
        img: image array forwarded to ``_func_max_size``.
        max_size: size limit forwarded to ``_func_max_size``.
        interpolation: interpolation flag forwarded to ``_func_max_size``.

    Returns:
        Whatever ``_func_max_size`` returns.
    """
    return _func_max_size(img, max_size, interpolation, max)


@preserve_channel_dim
def smallest_max_size(img, max_size, interpolation):
    """
    Apply ``_func_max_size`` to ``img`` using the built-in ``min`` as the selector.

    NOTE(review): presumably this rescales so the smallest side equals
    ``max_size``; confirm against ``_func_max_size``.

    Args:
        img: image array forwarded to ``_func_max_size``.
        max_size: size limit forwarded to ``_func_max_size``.
        interpolation: interpolation flag forwarded to ``_func_max_size``.

    Returns:
        Whatever ``_func_max_size`` returns.
    """
    return _func_max_size(img, max_size, interpolation, min)

Expand Down