-
Notifications
You must be signed in to change notification settings - Fork 280
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
556 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
# Copyright © 2024 Apple Inc. | ||
# | ||
# The code in this file is adapted from: | ||
# | ||
# google/flax: | ||
# Copyright 2024 The Flax Authors. | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
|
||
"""Adapted from flax.serialization with minor changes.""" | ||
|
||
import threading | ||
from contextlib import contextmanager | ||
from typing import Any, Callable, Dict, List, Type | ||
|
||
import jax | ||
|
||
_STATE_DICT_REGISTRY: Dict[Any, Any] = {} | ||
|
||
|
||
class _ErrorContext(threading.local): | ||
"""Context for deserialization error messages.""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.path = [] | ||
|
||
|
||
_error_context = _ErrorContext() | ||
|
||
|
||
@contextmanager | ||
def _record_path(name): | ||
try: | ||
_error_context.path.append(name) | ||
yield | ||
finally: | ||
_error_context.path.pop() | ||
|
||
|
||
def current_path(): | ||
"""Current state_dict path during deserialization for error messages.""" | ||
return "/".join(_error_context.path) | ||
|
||
|
||
class _NamedTuple: | ||
"""Fake type marker for namedtuple for registry.""" | ||
|
||
pass | ||
|
||
|
||
def _is_namedtuple(x: Any) -> bool: | ||
"""Duck typing test for namedtuple factory-generated objects.""" | ||
return isinstance(x, tuple) and hasattr(x, "_fields") | ||
|
||
|
||
def to_state_dict(target: Any) -> Dict[str, Any]: | ||
"""Returns a dictionary with the state of the given target. | ||
Equivalent to `flax.serialization.to_state_dict`. | ||
Args: | ||
target: The target instance to produce a state dict for. | ||
Returns: | ||
The state dict for the target. | ||
""" | ||
if _is_namedtuple(target): | ||
ty = _NamedTuple | ||
else: | ||
ty = type(target) | ||
if ty not in _STATE_DICT_REGISTRY: | ||
return target | ||
|
||
ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0] | ||
state_dict = ty_to_state_dict(target) | ||
if isinstance(state_dict, dict): | ||
for key in state_dict.keys(): | ||
if not isinstance(key, str): | ||
raise ValueError( | ||
"A state dict must only have string keys. " | ||
f"Instead, encountered key {key} of type {type(key)}." | ||
) | ||
return state_dict | ||
|
||
|
||
def from_state_dict(target: Any, state: Dict[str, Any], name: str = ".") -> Any: | ||
"""Restores the state of the given target using a state dict. | ||
Equivalent to `flax.serialization.from_state_dict`. | ||
Args: | ||
target: The object of which the state should be restored. | ||
state: A dictionary generated by `to_state_dict` with the desired new state for `target`. | ||
name: Name of branch taken, used to improve deserialization error messages. | ||
Returns: | ||
A copy of the object with the restored state. | ||
""" | ||
if _is_namedtuple(target): | ||
ty = _NamedTuple | ||
else: | ||
ty = type(target) | ||
if ty not in _STATE_DICT_REGISTRY: | ||
return state | ||
ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1] | ||
with _record_path(name): | ||
return ty_from_state_dict(target, state) | ||
|
||
|
||
def register_serialization_state( | ||
ty: Type, ty_to_state_dict: Callable, ty_from_state_dict: Callable, override: bool = False | ||
): | ||
"""Register a type for serialization. | ||
Equivalent to `flax.serialization.from_state_dict`. | ||
Args: | ||
ty: The type to be registered. | ||
ty_to_state_dict: A function that takes an instance of `ty` and returns its state as a | ||
dictionary. | ||
ty_from_state_dict: A function that takes an instance of `ty` and a state dict, and returns | ||
a copy of the instance with the restored state. | ||
override: Whether to override a previously registered serialization handler. | ||
""" | ||
if ty in _STATE_DICT_REGISTRY and not override: | ||
raise ValueError(f'A serialization handler for "{ty.__name__}" is already registered.') | ||
_STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict) | ||
|
||
|
||
# Below are serialization implementations for standard container types. | ||
|
||
|
||
def _list_state_dict(xs: List[Any]) -> Dict[str, Any]: | ||
return {str(i): to_state_dict(x) for i, x in enumerate(xs)} | ||
|
||
|
||
def _restore_list(xs: List[Any], state_dict: Dict[str, Any]) -> List[Any]: | ||
if len(state_dict) != len(xs): | ||
raise ValueError( | ||
"The size of the list and the state dict do not match, " | ||
f"got {len(xs)} and {len(state_dict)} at path {current_path()}" | ||
) | ||
return [from_state_dict(xs[i], state_dict[str(i)], name=str(i)) for i in range(len(xs))] | ||
|
||
|
||
def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]: | ||
str_keys = set(str(k) for k in xs.keys()) | ||
if len(str_keys) != len(xs): | ||
raise ValueError( | ||
"Dict keys do not have a unique string representation: " f"{str_keys} vs given: {xs}" | ||
) | ||
return {str(key): to_state_dict(value) for key, value in xs.items()} | ||
|
||
|
||
def _restore_dict(xs: Dict[str, Any], state_dict: Dict[str, Any]) -> Dict[str, Any]: | ||
diff = set(str(k) for k in xs.keys()).difference(state_dict.keys()) | ||
if diff: | ||
raise ValueError( | ||
"The target dict keys and state dict keys do not match, target dict " | ||
f"contains keys {diff} which are not present in state dict at path " | ||
f"{current_path()}" | ||
) | ||
return { | ||
key: from_state_dict(value, state_dict[str(key)], name=str(key)) | ||
for key, value in xs.items() | ||
} | ||
|
||
|
||
def _namedtuple_state_dict(xs) -> Dict[str, Any]: | ||
return {key: to_state_dict(getattr(xs, key)) for key in xs._fields} | ||
|
||
|
||
def _restore_namedtuple(xs, state_dict: Dict[str, Any]): | ||
state_keys = set(state_dict.keys()) | ||
namedtuple_keys = set(xs._fields) | ||
if state_keys != namedtuple_keys: | ||
raise ValueError( | ||
"The field names of the state dict and the named tuple do not match, " | ||
f"got {state_keys} and {namedtuple_keys} at path {current_path()}" | ||
) | ||
fields = {k: from_state_dict(getattr(xs, k), v, name=k) for k, v in state_dict.items()} | ||
return type(xs)(**fields) | ||
|
||
|
||
register_serialization_state(dict, _dict_state_dict, _restore_dict) | ||
register_serialization_state(list, _list_state_dict, _restore_list) | ||
register_serialization_state( | ||
tuple, | ||
_list_state_dict, | ||
lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)), | ||
) | ||
register_serialization_state(_NamedTuple, _namedtuple_state_dict, _restore_namedtuple) | ||
register_serialization_state( | ||
jax.tree_util.Partial, | ||
lambda x: ({"args": to_state_dict(x.args), "keywords": to_state_dict(x.keywords)}), | ||
lambda x, sd: jax.tree_util.Partial( | ||
x.func, | ||
*from_state_dict(x.args, sd["args"]), | ||
**from_state_dict(x.keywords, sd["keywords"]), | ||
), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright © 2024 Apple Inc. | ||
# | ||
# The code in this file is adapted from: | ||
# | ||
# google/flax: | ||
# Copyright 2024 The Flax Authors. | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
|
||
"""Tests for serialization utils.""" | ||
|
||
from typing import Any | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import optax | ||
from absl.testing import parameterized | ||
from jax.tree_util import Partial | ||
|
||
from axlearn.common import serialization, struct | ||
|
||
|
||
@struct.dataclass | ||
class _Point: | ||
x: float | ||
y: float | ||
meta: Any = struct.field(pytree_node=False) | ||
|
||
|
||
@struct.dataclass | ||
class _Box: | ||
value: int | ||
|
||
|
||
def _to_state_dict(box: _Box): | ||
return {"value": box.value} | ||
|
||
|
||
def _from_state_dict(box: _Box, state: Any): | ||
return box.replace(value=state["value"]) | ||
|
||
|
||
serialization.register_serialization_state(_Box, _to_state_dict, _from_state_dict, override=True) | ||
|
||
|
||
class SerializationTest(parameterized.TestCase): | ||
def test_dataclass_serialization(self): | ||
p = _Point(x=1, y=2, meta={"dummy": True}) | ||
state_dict = serialization.to_state_dict(p) | ||
self.assertEqual(state_dict, {"x": 1, "y": 2}) | ||
restored_p = serialization.from_state_dict(p, {"x": 3, "y": 4}) | ||
expected_p = _Point(x=3, y=4, meta={"dummy": True}) | ||
self.assertEqual(restored_p, expected_p) | ||
|
||
with self.assertRaises(ValueError): # Invalid field. | ||
serialization.from_state_dict(p, {"z": 3}) | ||
with self.assertRaises(ValueError): # Missing field. | ||
serialization.from_state_dict(p, {"x": 3}) | ||
|
||
def test_pass_through_serialization(self): | ||
p = _Box(value=123) | ||
state_dict = serialization.to_state_dict(p) | ||
self.assertEqual(state_dict, {"value": 123}) | ||
restored_box = serialization.from_state_dict(p, state_dict) | ||
expected_box = _Box(value=123) | ||
self.assertEqual(restored_box, expected_box) | ||
|
||
def test_model_serialization(self): | ||
initial_params = { | ||
"params": { | ||
"kernel": jnp.array([[1.0]], dtype=jnp.float32), | ||
"bias": jnp.array([0.0], dtype=jnp.float32), | ||
} | ||
} | ||
state = serialization.to_state_dict(initial_params) | ||
self.assertEqual(state, {"params": {"kernel": np.ones((1, 1)), "bias": np.zeros((1,))}}) | ||
state = {"params": {"kernel": np.zeros((1, 1)), "bias": np.zeros((1,))}} | ||
restored_model = serialization.from_state_dict(initial_params, state) | ||
self.assertEqual(restored_model, state) | ||
|
||
def test_partial_serialization(self): | ||
add_one = Partial(jnp.add, 1) | ||
state = serialization.to_state_dict(add_one) | ||
self.assertEqual(state, {"args": {"0": 1}, "keywords": {}}) | ||
restored_add_one = serialization.from_state_dict(add_one, state) | ||
self.assertEqual(add_one.args, restored_add_one.args) | ||
|
||
def test_optimizer_serialization(self): | ||
initial_params = { | ||
"params": { | ||
"kernel": jnp.array([[1.0]], dtype=jnp.float32), | ||
"bias": jnp.array([0.0], dtype=jnp.float32), | ||
} | ||
} | ||
tx = optax.sgd(0.1, momentum=0.1) | ||
tx_state = tx.init(initial_params) | ||
state = serialization.to_state_dict(tx_state) | ||
expected_state = { | ||
"0": { | ||
"trace": { | ||
"params": { | ||
"bias": np.array([0.0], dtype=jnp.float32), | ||
"kernel": np.array([[0.0]], dtype=jnp.float32), | ||
} | ||
} | ||
}, | ||
"1": {}, | ||
} | ||
self.assertEqual(state, expected_state) | ||
state = jax.tree_util.tree_map(lambda x: x + 1, expected_state) | ||
restored_tx_state = serialization.from_state_dict(tx_state, state) | ||
tx_state_plus1 = jax.tree_util.tree_map(lambda x: x + 1, tx_state) | ||
self.assertEqual(restored_tx_state, tx_state_plus1) | ||
|
||
def test_collection_serialization(self): | ||
@struct.dataclass | ||
class DummyDataClass: | ||
x: float | ||
|
||
@classmethod | ||
def initializer(cls, shape): | ||
del shape | ||
return cls(x=0.0) # pytype: disable=wrong-keyword-args | ||
|
||
variables = {"state": {"dummy": DummyDataClass(x=2.0)}} | ||
serialized_state_dict = serialization.to_state_dict(variables) | ||
self.assertEqual(serialized_state_dict, {"state": {"dummy": {"x": 2.0}}}) | ||
deserialized_state = serialization.from_state_dict(variables, serialized_state_dict) | ||
self.assertEqual(variables, deserialized_state) |
Oops, something went wrong.