Skip to content

Commit

Permalink
added fill value to Cutout constructor (albumentations-team#267)
Browse files Browse the repository at this point in the history
* added fill value to Cutout constructory

* fixed except clause

* flake8 fixes

* Delete test.py

dummy file

* Update transforms_interface.py

remove unnecessary if statements

* added fill value to CoarseDropouty

* fixed update params

* switched fill_value to optional in apply()

* fixing test fails

* fixed signatures

* flake8 fixes

* more flake8 fixes
  • Loading branch information
bfialkoff authored and ternaus committed Jun 25, 2019
1 parent f2afcda commit 7ee68e3
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
5 changes: 2 additions & 3 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,11 @@ def normalize(img, mean, std, max_pixel_value=255.0):
return img


def cutout(img, holes):
def cutout(img, holes, fill_value=0):
# Make a copy of the input image since we don't want to modify it directly
img = img.copy()

for x1, y1, x2, y2 in holes:
img[y1: y2, x1: x2] = 0
img[y1: y2, x1: x2] = fill_value
return img


Expand Down
19 changes: 10 additions & 9 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def targets_as_params(self):
return ['cropping_bbox']

def get_transform_init_args_names(self):
return ('max_part_shift', )
return ('max_part_shift',)


class RandomSizedCrop(DualTransform):
Expand Down Expand Up @@ -867,7 +867,7 @@ def __init__(self, alpha=1, sigma=50, alpha_affine=50, interpolation=cv2.INTER_L

def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params):
return F.elastic_transform(img, self.alpha, self.sigma, self.alpha_affine, interpolation,
self.border_mode, self. value, np.random.RandomState(random_state),
self.border_mode, self.value, np.random.RandomState(random_state),
self.approximate)

def get_params(self):
Expand Down Expand Up @@ -922,15 +922,16 @@ class Cutout(ImageOnlyTransform):
| https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
"""

def __init__(self, num_holes=8, max_h_size=8, max_w_size=8, always_apply=False, p=0.5):
def __init__(self, num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, always_apply=False, p=0.5):
super(Cutout, self).__init__(always_apply, p)
self.num_holes = num_holes
self.max_h_size = max_h_size
self.max_w_size = max_w_size
self.fill_value = fill_value
warnings.warn("This class has been deprecated. Please use CoarseDropout", DeprecationWarning)

def apply(self, image, holes=[], **params):
return F.cutout(image, holes)
def apply(self, image, fill_value=0, holes=[], **params):
return F.cutout(image, holes, fill_value)

def get_params_dependent_on_targets(self, params):
img = params['image']
Expand Down Expand Up @@ -986,21 +987,21 @@ class CoarseDropout(ImageOnlyTransform):

def __init__(self, max_holes=8, max_height=8, max_width=8,
min_holes=None, min_height=None, min_width=None,
always_apply=False, p=0.5):
fill_value=0, always_apply=False, p=0.5):
super(CoarseDropout, self).__init__(always_apply, p)
self.max_holes = max_holes
self.max_height = max_height
self.max_width = max_width
self.min_holes = min_holes if min_holes is not None else max_holes
self.min_height = min_height if min_height is not None else max_height
self.min_width = min_width if min_width is not None else max_width

self.fill_value = fill_value
assert 0 < self.min_holes <= self.max_holes
assert 0 < self.min_height <= self.max_height
assert 0 < self.min_width <= self.max_width

def apply(self, image, holes=[], **params):
return F.cutout(image, holes)
def apply(self, image, fill_value=0, holes=[], **params):
return F.cutout(image, holes, fill_value)

def get_params_dependent_on_targets(self, params):
img = params['image']
Expand Down
6 changes: 5 additions & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def targets(self):
def update_params(self, params, **kwargs):
if hasattr(self, 'interpolation'):
params['interpolation'] = self.interpolation
params.update({'cols': kwargs['image'].shape[1], 'rows': kwargs['image'].shape[0]})
if hasattr(self, 'fill_value'):
params['interpolation'] = self.fill_value
params.update(
{'cols': kwargs['image'].shape[1],
'rows': kwargs['image'].shape[0]})
return params

@property
Expand Down
5 changes: 4 additions & 1 deletion albumentations/imgaug/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import imgaug as ia
from imgaug import augmenters as iaa
try:
from imgaug import augmenters as iaa
except ImportError:
import imgaug.imgaug.augmenters as iaa

from ..augmentations.bbox_utils import convert_bboxes_from_albumentations, \
convert_bboxes_to_albumentations
Expand Down

0 comments on commit 7ee68e3

Please sign in to comment.