Skip to content

Commit

Permalink
Added ChannelDropout transform (albumentations-team#276)
Browse files Browse the repository at this point in the history
* Added ChannelDropout transform

* input parameters changed to range

* Update albumentations/augmentations/functional.py

Co-Authored-By: Eugene Khvedchenya <[email protected]>

* Update albumentations/augmentations/functional.py

Co-Authored-By: Eugene Khvedchenya <[email protected]>
  • Loading branch information
ternaus and BloodAxe authored Jun 20, 2019
1 parent 53fdd24 commit f2afcda
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Pixel-level transforms will change just an input image and will leave any additi

- [Blur](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Blur)
- [CLAHE](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CLAHE)
- [ChannelDropout](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ChannelDropout)
- [ChannelShuffle](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ChannelShuffle)
- [CoarseDropout](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CoarseDropout)
- [Cutout](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Cutout)
Expand Down
12 changes: 12 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,18 @@ def channel_shuffle(img, channels_shuffled):
return img


@preserve_shape
def channel_dropout(img, channels_to_drop, fill_value=0):
if len(img.shape) == 2 or img.shape[2] == 1:
raise NotImplementedError("Only one channel. ChannelDropout is not defined.")

img = img.copy()

img[..., channels_to_drop] = fill_value

return img


@preserve_shape
def gamma_transform(img, gamma):
if img.dtype == np.uint8:
Expand Down
50 changes: 50 additions & 0 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
'Resize', 'RandomSizedCrop', 'RandomBrightnessContrast',
'RandomCropNearBBox', 'RandomSizedBBoxSafeCrop', 'RandomSnow',
'RandomRain', 'RandomFog', 'RandomSunFlare', 'RandomShadow', 'Lambda',
'ChannelDropout',
]


Expand Down Expand Up @@ -1787,6 +1788,55 @@ def get_transform_init_args_names(self):
return ('clip_limit', 'tile_grid_size')


class ChannelDropout(ImageOnlyTransform):
"""Randomly Drop Channels in the input Image.
Args:
channel_drop_range (int, int): range from which we choose the number of channels to drop.
fill_value : pixel value for the dropped channel.
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
Image types:
uint8, uint16, unit32, float32
"""

def __init__(self, channel_drop_range=(1, 1), fill_value=0, always_apply=False, p=0.5):
super(ChannelDropout, self).__init__(always_apply, p)

self.min_channels = channel_drop_range[0]
self.max_channels = channel_drop_range[1]

assert 1 <= self.min_channels <= self.max_channels

self.fill_value = fill_value

def apply(self, img, channels_to_drop=(0, ), **params):
return F.channel_dropout(img, channels_to_drop, self.fill_value)

def get_params_dependent_on_targets(self, params):
img = params['image']

num_channels = img.shape[-1]

if len(img.shape) == 2 or num_channels == 1:
raise NotImplementedError("Images has one channel. ChannelDropout is not defined.")

if self.max_channels >= num_channels:
raise ValueError("Can not drop all channels in ChannelDropout.")

num_drop_channels = random.randint(self.min_channels, self.max_channels)

channels_to_drop = random.choice(range(num_channels), size=num_drop_channels, replace=False)

return {'channels_to_drop': channels_to_drop}

def get_transform_init_args_names(self):
return ('channel_drop_range', 'fill_value')


class ChannelShuffle(ImageOnlyTransform):
"""Randomly rearrange channels of the input RGB image.
Expand Down
9 changes: 7 additions & 2 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
IAASharpen, IAAAdditiveGaussianNoise, IAAPiecewiseAffine, IAAPerspective,
Cutout, CoarseDropout, Normalize, ToFloat, FromFloat,
RandomBrightnessContrast, RandomSnow, RandomRain, RandomFog,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop)
import albumentations as A
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop,
ChannelDropout)


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
Expand All @@ -37,6 +37,7 @@
[RandomFog, {}],
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
])
def test_image_only_augmentations(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -68,6 +69,7 @@ def test_image_only_augmentations(augmentation_cls, params, image, mask):
[RandomFog, {}],
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
])
def test_image_only_augmentations_with_float_values(augmentation_cls, params, float_image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -182,6 +184,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
[RandomFog, {}],
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
])
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
image_copy = image.copy()
Expand Down Expand Up @@ -231,6 +234,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
[RandomFog, {}],
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
])
def test_augmentations_wont_change_float_input(augmentation_cls, params, float_image):
float_image_copy = float_image.copy()
Expand Down Expand Up @@ -326,6 +330,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
[RandomFog, {}],
[RandomSunFlare, {}],
[RandomShadow, {}],
[ChannelDropout, {}],
])
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down
1 change: 1 addition & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def test_force_apply():
[A.Normalize, {}],
[A.ToFloat, {}],
[A.FromFloat, {}],
[A.ChannelDropout, {}],
])
def test_additional_targets_for_image_only(augmentation_cls, params):
aug = A.Compose(
Expand Down

0 comments on commit f2afcda

Please sign in to comment.