Commit 372d9d6

Serialization part 2 (albumentations-team#262)

creafz authored Jun 7, 2019
1 parent a7f085d commit 372d9d6

Showing 12 changed files with 723 additions and 49 deletions.
2 changes: 1 addition & 1 deletion albumentations/__init__.py
@@ -1,6 +1,6 @@
from __future__ import absolute_import

__version__ = '0.2.3'
__version__ = '0.3.0'

from .core.composition import *
from .core.transforms_interface import *
23 changes: 19 additions & 4 deletions albumentations/augmentations/transforms.py
@@ -1,7 +1,6 @@
from __future__ import absolute_import, division

from types import LambdaType

import math
import random
import warnings
@@ -12,6 +11,7 @@
from . import functional as F
from .bbox_utils import union_of_bboxes, denormalize_bbox, normalize_bbox
from ..core.transforms_interface import to_tuple, DualTransform, ImageOnlyTransform, NoOp
from ..core.utils import format_args

__all__ = [
'Blur', 'VerticalFlip', 'HorizontalFlip', 'Flip', 'Normalize', 'Transpose',
@@ -1968,9 +1968,10 @@ class Lambda(NoOp):
"""

def __init__(self, image=None, mask=None, keypoint=None, bbox=None, always_apply=False, p=1.0):
def __init__(self, image=None, mask=None, keypoint=None, bbox=None, name=None, always_apply=False, p=1.0):
super(Lambda, self).__init__(always_apply, p)

self.name = name
self.custom_apply_fns = {target_name: F.noop for target_name in ('image', 'mask', 'keypoint', 'bbox')}
for target_name, custom_apply_fn in {'image': image, 'mask': mask, 'keypoint': keypoint, 'bbox': bbox}.items():
if custom_apply_fn is not None:
@@ -1996,5 +1997,19 @@ def apply_to_keypoint(self, keypoint, **params):
fn = self.custom_apply_fns['keypoint']
return fn(keypoint, **params)

def to_dict(self):
raise NotImplementedError('Lambda is not serializable')
def _to_dict(self):
if self.name is None:
raise ValueError(
"To make a Lambda transform serializable you should provide the `name` argument, "
"e.g. `Lambda(name='my_transform', image=<some func>, ...)`."
)
return {
'__type__': 'Lambda',
'__name__': self.name,
}

def __repr__(self):
state = {'name': self.name}
state.update(self.custom_apply_fns.items())
state.update(self.get_base_init_args())
return '{name}({args})'.format(name=self.__class__.__name__, args=format_args(state))
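
As a hedged illustration of the new `name` argument (the function and its name below are placeholders, and the serialization helpers are assumed to be re-exported at the package level as in the accompanying notebook), a named Lambda transform can now take part in serialization:

    import albumentations as A

    def add_noise(image, **kwargs):
        # placeholder custom function; Lambda calls it with the image and extra params
        return image

    aug = A.Lambda(name='add_noise', image=add_noise, p=1.0)
    print(aug)                   # repr is now built with format_args
    serialized = A.to_dict(aug)  # no longer raises; only the type and the name are stored
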
48 changes: 24 additions & 24 deletions albumentations/core/composition.py
@@ -10,6 +10,7 @@
from albumentations.core.serialization import SerializableMeta
from albumentations.core.six import add_metaclass
from albumentations.core.transforms_interface import DualTransform
from albumentations.core.utils import format_args
from albumentations.imgaug.transforms import DualIAATransform
from albumentations.augmentations.bbox_utils import convert_bboxes_from_albumentations, \
convert_bboxes_to_albumentations, filter_bboxes, check_bboxes
@@ -67,6 +68,7 @@ def __repr__(self):
return self.indented_repr()

def indented_repr(self, indent=REPR_INDENT_STEP):
args = {k: v for k, v in self._to_dict().items() if not (k.startswith('__') or k == 'transforms')}
repr_string = self.__class__.__name__ + '(['
for t in self.transforms:
repr_string += '\n'
@@ -75,18 +77,18 @@ def indented_repr(self, indent=REPR_INDENT_STEP):
else:
t_repr = repr(t)
repr_string += ' ' * indent + t_repr + ','
repr_string += '\n' + ' ' * (indent - REPR_INDENT_STEP) + '], p={p})'.format(p=self.p)
repr_string += '\n' + ' ' * (indent - REPR_INDENT_STEP) + '], {args})'.format(args=format_args(args))
return repr_string

@classmethod
def get_class_fullname(cls):
return '{cls.__module__}.{cls.__name__}'.format(cls=cls)

def to_dict(self):
def _to_dict(self):
return {
'__class_fullname__': self.get_class_fullname(),
'p': self.p,
'transforms': [t.to_dict() for t in self.transforms]
'transforms': [t._to_dict() for t in self.transforms],
}

def add_targets(self, additional_targets):
@@ -120,30 +122,19 @@ class Compose(BaseCompose):
| to remain this box in list. Default: 0.0.
"""

def __init__(self, transforms, preprocessing_transforms=[], postprocessing_transforms=[],
to_tensor=None, bbox_params={}, keypoint_params={}, additional_targets={}, p=1.0):
if preprocessing_transforms:
warnings.warn("preprocessing transforms are deprecated, use always_apply flag for this purpose. "
"will be removed in 0.3.0", DeprecationWarning)
set_always_apply(preprocessing_transforms)
if postprocessing_transforms:
warnings.warn("postprocessing transforms are deprecated, use always_apply flag for this purpose"
"will be removed in 0.3.0", DeprecationWarning)
set_always_apply(postprocessing_transforms)
if to_tensor is not None:
warnings.warn("to_tensor in Compose is deprecated, use always_apply flag for this purpose"
"will be removed in 0.3.0", DeprecationWarning)
to_tensor.always_apply = True
# todo deprecated
_transforms = (preprocessing_transforms +
[t for t in transforms if t is not None] +
postprocessing_transforms)
if to_tensor is not None:
_transforms.append(to_tensor)
super(Compose, self).__init__(_transforms, p)
def __init__(self, transforms, bbox_params=None, keypoint_params=None, additional_targets=None, p=1.0):
super(Compose, self).__init__([t for t in transforms if t is not None], p)

if bbox_params is None:
bbox_params = {}
if keypoint_params is None:
keypoint_params = {}
if additional_targets is None:
additional_targets = {}

self.bboxes_name = 'bboxes'
self.keypoints_name = 'keypoints'
self.additional_targets = additional_targets
self.params = {
self.bboxes_name: bbox_params,
self.keypoints_name: keypoint_params
@@ -212,6 +203,15 @@ def __call__(self, force_apply=False, **data):

return data

def _to_dict(self):
dictionary = super(Compose, self)._to_dict()
dictionary.update({
'bbox_params': self.params[self.bboxes_name],
'keypoint_params': self.params[self.keypoints_name],
'additional_targets': self.additional_targets,
})
return dictionary
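
To show the effect of this method, a sketch of the dictionary a Compose now serializes to (parameter values are arbitrary, the nested HorizontalFlip entry is abbreviated, and the outer `transform` wrapper comes from `to_dict` in serialization.py):

    import albumentations as A

    pipeline = A.Compose(
        [A.HorizontalFlip(p=0.5)],
        bbox_params={'format': 'coco', 'label_fields': ['labels']},
    )
    A.to_dict(pipeline)['transform']
    # {'__class_fullname__': 'albumentations.core.composition.Compose',
    #  'p': 1.0,
    #  'transforms': [{'__class_fullname__': '...HorizontalFlip', 'always_apply': False, 'p': 0.5}],
    #  'bbox_params': {'format': 'coco', 'label_fields': ['labels']},
    #  'keypoint_params': {},
    #  'additional_targets': {}}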


def data_postprocessing(data_name, params, check_fn, filter_fn, convert_fn, data):
rows, cols = data['image'].shape[:2]
33 changes: 28 additions & 5 deletions albumentations/core/serialization.py
@@ -1,3 +1,5 @@
from __future__ import absolute_import

import json
import warnings

@@ -47,7 +49,7 @@ def to_dict(transform, on_not_implemented_error='raise'):
)
)
try:
transform_dict = transform.to_dict()
transform_dict = transform._to_dict()
except NotImplementedError as e:
if on_not_implemented_error == 'raise':
raise e
@@ -67,17 +69,34 @@ def to_dict(transform, on_not_implemented_error='raise'):
}


def from_dict(transform_dict):
def from_dict(transform_dict, lambda_transforms=None):
"""
Args:
transform (dict): A dictionary with serialized transform pipeline.
lambda_transforms (dict): A dictionary that contains lambda transforms, that is instances of the Lambda class.
This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
in that dictionary should be named same as `name` arguments in respective lambda transforms from
a serialized pipeline.
"""
transform = transform_dict['transform']
if transform.get('__type__') == 'Lambda':
name = transform['__name__']
if lambda_transforms is None:
raise ValueError(
'To deserialize a Lambda transform with name {name} you need to pass a dict with this transform '
'as the `lambda_transforms` argument'.format(name=name)
)
transform = lambda_transforms.get(name)
if transform is None:
raise ValueError('Lambda transform with {name} was not found in `lambda_transforms`'.format(name=name))
return transform
name = transform['__class_fullname__']
args = {k: v for k, v in transform.items() if k != '__class_fullname__'}
cls = SERIALIZABLE_REGISTRY[name]
if 'transforms' in args:
args['transforms'] = [from_dict({'transform': t}) for t in args['transforms']]
args['transforms'] = [
from_dict({'transform': t}, lambda_transforms=lambda_transforms) for t in args['transforms']
]
return cls(**args)
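
A hedged round-trip sketch for the new `lambda_transforms` argument; `add_noise` is a placeholder, and because Lambda function bodies are never written to the serialized dict, the same instances must be supplied by name when restoring:

    import albumentations as A

    def add_noise(image, **kwargs):
        return image  # placeholder body

    noise = A.Lambda(name='add_noise', image=add_noise)
    pipeline = A.Compose([A.HorizontalFlip(), noise])

    serialized = A.to_dict(pipeline)
    restored = A.from_dict(serialized, lambda_transforms={'add_noise': noise})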


@@ -108,17 +127,21 @@ def save(transform, filepath, data_format='json', on_not_implemented_error='raise'):
dump_fn(transform_dict, f)


def load(filepath, data_format='json'):
def load(filepath, data_format='json', lambda_transforms=None):
"""
Load a serialized pipeline from a json or yaml file and construct a transform pipeline.
Args:
transform (obj): Transform to serialize.
filepath (str): Filepath to read from.
data_format (str): Serialization format. Should be either `json` or 'yaml'.
lambda_transforms (dict): A dictionary that contains lambda transforms, that is instances of the Lambda class.
This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
in that dictionary should be named same as `name` arguments in respective lambda transforms from
a serialized pipeline.
"""
check_data_format(data_format)
load_fn = json.load if data_format == 'json' else yaml.safe_load
with open(filepath) as f:
transform_dict = load_fn(f)
return from_dict(transform_dict)
return from_dict(transform_dict, lambda_transforms=lambda_transforms)
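
The file-based equivalent, sketched under the same assumptions (the filepath is illustrative):

    import albumentations as A

    pipeline = A.Compose([A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.3)])
    A.save(pipeline, '/tmp/pipeline.json')   # data_format defaults to 'json'
    restored = A.load('/tmp/pipeline.json')

    # A pipeline containing named Lambda transforms needs them passed back explicitly:
    # restored = A.load('/tmp/pipeline.json', lambda_transforms={'add_noise': noise})
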
13 changes: 13 additions & 0 deletions albumentations/core/six.py
@@ -18,6 +18,19 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import absolute_import

import sys


PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3

if PY3:
string_types = str,
else:
string_types = basestring, # noqa: F821


def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass."""
8 changes: 5 additions & 3 deletions albumentations/core/transforms_interface.py
@@ -1,9 +1,12 @@
from __future__ import absolute_import

import random

import cv2

from albumentations.core.serialization import SerializableMeta
from albumentations.core.six import add_metaclass
from albumentations.core.utils import format_args

__all__ = ['to_tuple', 'BasicTransform', 'DualTransform', 'ImageOnlyTransform', 'NoOp']

@@ -68,8 +71,7 @@ def __call__(self, force_apply=False, **kwargs):
def __repr__(self):
state = self.get_base_init_args()
state.update(self.get_transform_init_args())
args = ', '.join(['{0}={1}'.format(k, v) for k, v in state.items()])
return '{name}({args})'.format(name=self.__class__.__name__, args=args)
return '{name}({args})'.format(name=self.__class__.__name__, args=format_args(state))

def _get_target_function(self, key):
transform_key = key
@@ -142,7 +144,7 @@ def get_base_init_args(self):
def get_transform_init_args(self):
return {k: getattr(self, k) for k in self.get_transform_init_args_names()}

def to_dict(self):
def _to_dict(self):
state = {
'__class_fullname__': self.get_class_fullname(),
}
12 changes: 12 additions & 0 deletions albumentations/core/utils.py
@@ -0,0 +1,12 @@
from __future__ import absolute_import

from ..core.six import string_types


def format_args(args_dict):
formatted_args = []
for k, v in args_dict.items():
if isinstance(v, string_types):
v = "'{}'".format(v)
formatted_args.append('{}={}'.format(k, v))
return ', '.join(formatted_args)
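
For reference, a small illustration of what the new helper produces (the values are arbitrary; argument order follows the dict):

    from albumentations.core.utils import format_args

    format_args({'p': 0.5, 'interpolation': 1, 'border_mode': 'reflect'})
    # "p=0.5, interpolation=1, border_mode='reflect'"
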
464 changes: 464 additions & 0 deletions notebooks/serialization.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -41,12 +41,12 @@ def mask():

@pytest.fixture
def bboxes():
return [[15, 12, 75, 30], [55, 25, 90, 90]]
return [[15, 12, 75, 30, 1], [55, 25, 90, 90, 2]]


@pytest.fixture
def keypoints():
return [[20, 30, 40, 50], [20, 30, 60, 80]]
return [[20, 30, 40, 50, 1], [20, 30, 60, 80, 2]]


@pytest.fixture
2 changes: 2 additions & 0 deletions tests/test_augmentations.py
@@ -11,6 +11,7 @@
Cutout, CoarseDropout, Normalize, ToFloat, FromFloat,
RandomBrightnessContrast, RandomSnow, RandomRain, RandomFog,
RandomSunFlare, RandomCropNearBBox, RandomShadow, RandomSizedCrop)
import albumentations as A


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
@@ -140,6 +141,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[Cutout, {}],
[JpegCompression, {}],
[HueSaturationValue, {}],
[RGBShift, {}],
10 changes: 0 additions & 10 deletions tests/test_core.py
@@ -30,16 +30,6 @@ def test_compose():
assert second.called


def test_compose_to_tensor():
first = MagicMock()
second = MagicMock()
to_tensor = MagicMock()
augmentation = Compose([first, second], to_tensor=to_tensor, p=0)
image = np.ones((8, 8))
augmentation(image=image)
assert to_tensor.called


def oneof_always_apply_crash():
aug = Compose([
HorizontalFlip(),