Skip to content

Commit

Permalink
Decouples from flax. (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
markblee authored Mar 19, 2024
1 parent 680130e commit 204f3de
Show file tree
Hide file tree
Showing 8 changed files with 556 additions and 14 deletions.
201 changes: 201 additions & 0 deletions axlearn/common/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright © 2024 Apple Inc.
#
# The code in this file is adapted from:
#
# google/flax:
# Copyright 2024 The Flax Authors.
# Licensed under the Apache License, Version 2.0 (the "License").

"""Adapted from flax.serialization with minor changes."""

import threading
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Type

import jax

_STATE_DICT_REGISTRY: Dict[Any, Any] = {}


class _ErrorContext(threading.local):
"""Context for deserialization error messages."""

def __init__(self):
super().__init__()
self.path = []


_error_context = _ErrorContext()


@contextmanager
def _record_path(name):
try:
_error_context.path.append(name)
yield
finally:
_error_context.path.pop()


def current_path():
"""Current state_dict path during deserialization for error messages."""
return "/".join(_error_context.path)


class _NamedTuple:
"""Fake type marker for namedtuple for registry."""

pass


def _is_namedtuple(x: Any) -> bool:
"""Duck typing test for namedtuple factory-generated objects."""
return isinstance(x, tuple) and hasattr(x, "_fields")


def to_state_dict(target: Any) -> Dict[str, Any]:
"""Returns a dictionary with the state of the given target.
Equivalent to `flax.serialization.to_state_dict`.
Args:
target: The target instance to produce a state dict for.
Returns:
The state dict for the target.
"""
if _is_namedtuple(target):
ty = _NamedTuple
else:
ty = type(target)
if ty not in _STATE_DICT_REGISTRY:
return target

ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0]
state_dict = ty_to_state_dict(target)
if isinstance(state_dict, dict):
for key in state_dict.keys():
if not isinstance(key, str):
raise ValueError(
"A state dict must only have string keys. "
f"Instead, encountered key {key} of type {type(key)}."
)
return state_dict


def from_state_dict(target: Any, state: Dict[str, Any], name: str = ".") -> Any:
"""Restores the state of the given target using a state dict.
Equivalent to `flax.serialization.from_state_dict`.
Args:
target: The object of which the state should be restored.
state: A dictionary generated by `to_state_dict` with the desired new state for `target`.
name: Name of branch taken, used to improve deserialization error messages.
Returns:
A copy of the object with the restored state.
"""
if _is_namedtuple(target):
ty = _NamedTuple
else:
ty = type(target)
if ty not in _STATE_DICT_REGISTRY:
return state
ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1]
with _record_path(name):
return ty_from_state_dict(target, state)


def register_serialization_state(
ty: Type, ty_to_state_dict: Callable, ty_from_state_dict: Callable, override: bool = False
):
"""Register a type for serialization.
Equivalent to `flax.serialization.from_state_dict`.
Args:
ty: The type to be registered.
ty_to_state_dict: A function that takes an instance of `ty` and returns its state as a
dictionary.
ty_from_state_dict: A function that takes an instance of `ty` and a state dict, and returns
a copy of the instance with the restored state.
override: Whether to override a previously registered serialization handler.
"""
if ty in _STATE_DICT_REGISTRY and not override:
raise ValueError(f'A serialization handler for "{ty.__name__}" is already registered.')
_STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict)


# Below are serialization implementations for standard container types.


def _list_state_dict(xs: List[Any]) -> Dict[str, Any]:
return {str(i): to_state_dict(x) for i, x in enumerate(xs)}


def _restore_list(xs: List[Any], state_dict: Dict[str, Any]) -> List[Any]:
if len(state_dict) != len(xs):
raise ValueError(
"The size of the list and the state dict do not match, "
f"got {len(xs)} and {len(state_dict)} at path {current_path()}"
)
return [from_state_dict(xs[i], state_dict[str(i)], name=str(i)) for i in range(len(xs))]


def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]:
str_keys = set(str(k) for k in xs.keys())
if len(str_keys) != len(xs):
raise ValueError(
"Dict keys do not have a unique string representation: " f"{str_keys} vs given: {xs}"
)
return {str(key): to_state_dict(value) for key, value in xs.items()}


def _restore_dict(xs: Dict[str, Any], state_dict: Dict[str, Any]) -> Dict[str, Any]:
diff = set(str(k) for k in xs.keys()).difference(state_dict.keys())
if diff:
raise ValueError(
"The target dict keys and state dict keys do not match, target dict "
f"contains keys {diff} which are not present in state dict at path "
f"{current_path()}"
)
return {
key: from_state_dict(value, state_dict[str(key)], name=str(key))
for key, value in xs.items()
}


def _namedtuple_state_dict(xs) -> Dict[str, Any]:
return {key: to_state_dict(getattr(xs, key)) for key in xs._fields}


def _restore_namedtuple(xs, state_dict: Dict[str, Any]):
state_keys = set(state_dict.keys())
namedtuple_keys = set(xs._fields)
if state_keys != namedtuple_keys:
raise ValueError(
"The field names of the state dict and the named tuple do not match, "
f"got {state_keys} and {namedtuple_keys} at path {current_path()}"
)
fields = {k: from_state_dict(getattr(xs, k), v, name=k) for k, v in state_dict.items()}
return type(xs)(**fields)


register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(
tuple,
_list_state_dict,
lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)),
)
register_serialization_state(_NamedTuple, _namedtuple_state_dict, _restore_namedtuple)
register_serialization_state(
jax.tree_util.Partial,
lambda x: ({"args": to_state_dict(x.args), "keywords": to_state_dict(x.keywords)}),
lambda x, sd: jax.tree_util.Partial(
x.func,
*from_state_dict(x.args, sd["args"]),
**from_state_dict(x.keywords, sd["keywords"]),
),
)
129 changes: 129 additions & 0 deletions axlearn/common/serialization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright © 2024 Apple Inc.
#
# The code in this file is adapted from:
#
# google/flax:
# Copyright 2024 The Flax Authors.
# Licensed under the Apache License, Version 2.0 (the "License").

"""Tests for serialization utils."""

from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl.testing import parameterized
from jax.tree_util import Partial

from axlearn.common import serialization, struct


@struct.dataclass
class _Point:
x: float
y: float
meta: Any = struct.field(pytree_node=False)


@struct.dataclass
class _Box:
value: int


def _to_state_dict(box: _Box):
return {"value": box.value}


def _from_state_dict(box: _Box, state: Any):
return box.replace(value=state["value"])


serialization.register_serialization_state(_Box, _to_state_dict, _from_state_dict, override=True)


class SerializationTest(parameterized.TestCase):
def test_dataclass_serialization(self):
p = _Point(x=1, y=2, meta={"dummy": True})
state_dict = serialization.to_state_dict(p)
self.assertEqual(state_dict, {"x": 1, "y": 2})
restored_p = serialization.from_state_dict(p, {"x": 3, "y": 4})
expected_p = _Point(x=3, y=4, meta={"dummy": True})
self.assertEqual(restored_p, expected_p)

with self.assertRaises(ValueError): # Invalid field.
serialization.from_state_dict(p, {"z": 3})
with self.assertRaises(ValueError): # Missing field.
serialization.from_state_dict(p, {"x": 3})

def test_pass_through_serialization(self):
p = _Box(value=123)
state_dict = serialization.to_state_dict(p)
self.assertEqual(state_dict, {"value": 123})
restored_box = serialization.from_state_dict(p, state_dict)
expected_box = _Box(value=123)
self.assertEqual(restored_box, expected_box)

def test_model_serialization(self):
initial_params = {
"params": {
"kernel": jnp.array([[1.0]], dtype=jnp.float32),
"bias": jnp.array([0.0], dtype=jnp.float32),
}
}
state = serialization.to_state_dict(initial_params)
self.assertEqual(state, {"params": {"kernel": np.ones((1, 1)), "bias": np.zeros((1,))}})
state = {"params": {"kernel": np.zeros((1, 1)), "bias": np.zeros((1,))}}
restored_model = serialization.from_state_dict(initial_params, state)
self.assertEqual(restored_model, state)

def test_partial_serialization(self):
add_one = Partial(jnp.add, 1)
state = serialization.to_state_dict(add_one)
self.assertEqual(state, {"args": {"0": 1}, "keywords": {}})
restored_add_one = serialization.from_state_dict(add_one, state)
self.assertEqual(add_one.args, restored_add_one.args)

def test_optimizer_serialization(self):
initial_params = {
"params": {
"kernel": jnp.array([[1.0]], dtype=jnp.float32),
"bias": jnp.array([0.0], dtype=jnp.float32),
}
}
tx = optax.sgd(0.1, momentum=0.1)
tx_state = tx.init(initial_params)
state = serialization.to_state_dict(tx_state)
expected_state = {
"0": {
"trace": {
"params": {
"bias": np.array([0.0], dtype=jnp.float32),
"kernel": np.array([[0.0]], dtype=jnp.float32),
}
}
},
"1": {},
}
self.assertEqual(state, expected_state)
state = jax.tree_util.tree_map(lambda x: x + 1, expected_state)
restored_tx_state = serialization.from_state_dict(tx_state, state)
tx_state_plus1 = jax.tree_util.tree_map(lambda x: x + 1, tx_state)
self.assertEqual(restored_tx_state, tx_state_plus1)

def test_collection_serialization(self):
@struct.dataclass
class DummyDataClass:
x: float

@classmethod
def initializer(cls, shape):
del shape
return cls(x=0.0) # pytype: disable=wrong-keyword-args

variables = {"state": {"dummy": DummyDataClass(x=2.0)}}
serialized_state_dict = serialization.to_state_dict(variables)
self.assertEqual(serialized_state_dict, {"state": {"dummy": {"x": 2.0}}})
deserialized_state = serialization.from_state_dict(variables, serialized_state_dict)
self.assertEqual(variables, deserialized_state)
Loading

0 comments on commit 204f3de

Please sign in to comment.