[Bug fix] Jax rollback for TPU #855

Closed. Wants to merge 2 commits.
axlearn/common/checkpointer_orbax.py (7 additions, 2 deletions)
@@ -226,6 +226,12 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
             step_prefix=STEP_PREFIX,
             step_format_fixed_length=STEP_NUM_DIGITS,
         )
+        # TODO(matthew_e_hopkins): bring back save_concurrent_gb and restore_concurrent_gb
+        # after bumping up the Jax version.
+        if cfg.max_concurrent_restore_gb is not None:
+            raise NotImplementedError(
+                "Orbax version (0.5.23) doesn't support separate save/restore concurrent_gb."
+            )
         self._manager = ocp.CheckpointManager(
             directory=cfg.dir,
             options=ocp.CheckpointManagerOptions(
@@ -245,8 +251,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
                 # Note that this defaults to use_ocdb=True. Note also that custom `TypeHandler`s are
                 # ignored by `StandardCheckpointHandler`, so we use `PyTreeCheckpointHandler`.
                 "state": ocp.PyTreeCheckpointHandler(
-                    save_concurrent_gb=cfg.max_concurrent_save_gb,
-                    restore_concurrent_gb=cfg.max_concurrent_restore_gb,
+                    concurrent_gb=cfg.max_concurrent_save_gb,
                 ),
             },
         )
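For reference, a minimal sketch (not part of this PR) of the rolled-back construction path: under orbax-checkpoint==0.5.23 as pinned below, PyTreeCheckpointHandler takes a single concurrent_gb argument, so only the save limit can be honored. The cfg.max_concurrent_save_gb / cfg.max_concurrent_restore_gb field names are taken from the diff above; the helper name is hypothetical.

import orbax.checkpoint as ocp

def build_state_handler(cfg):
    # Orbax 0.5.23 exposes one shared concurrency limit, so a separate restore
    # limit cannot be set until the Jax/Orbax pins are bumped again.
    if cfg.max_concurrent_restore_gb is not None:
        raise NotImplementedError(
            "Orbax version (0.5.23) doesn't support separate save/restore concurrent_gb."
        )
    return ocp.PyTreeCheckpointHandler(concurrent_gb=cfg.max_concurrent_save_gb)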
axlearn/common/checkpointer_test.py (27 additions, 14 deletions)
@@ -121,20 +121,33 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]):
                 x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([3], dtype=jnp.float32)
             ),
         )
-
-        # When the given state has a different dict shape: [1] instead of [] for x.
-        # Orbax throws AssertionError in this case.
-        with self.assertRaisesRegex(
-            (AssertionError, ValueError),
-            "(checkpoint tree dtypes or shapes|not compatible)",
-        ):
-            ckpt.restore(
-                step=None,
-                state=dict(
-                    x=jnp.zeros([1], dtype=jnp.int32),
-                    y=jnp.ones([2], dtype=jnp.float32),
-                ),
-            )
+        # TODO(matthew_e_hopkins): revert it once upgrade jax version.
+        if checkpointer_cls is Checkpointer:
+            # When the given state has a different dict shape: [1] instead of [] for x.
+            # Orbax throws AssertionError in this case.
+            with self.assertRaisesRegex(
+                (AssertionError, ValueError),
+                "(checkpoint tree dtypes or shapes|not compatible)",
+            ):
+                ckpt.restore(
+                    step=None,
+                    state=dict(
+                        x=jnp.zeros([1], dtype=jnp.int32),
+                        y=jnp.ones([2], dtype=jnp.float32),
+                    ),
+                )
+        else:
+            with self.assertRaisesRegex(
+                (AssertionError, ValueError),
+                "Cannot intersect index domain",
+            ):
+                ckpt.restore(
+                    step=None,
+                    state=dict(
+                        x=jnp.zeros([1], dtype=jnp.int32),
+                        y=jnp.ones([2], dtype=jnp.float32),
+                    ),
+                )
 
         # When the given state has a different dtype: float32 instead of int32 for x.
         with self.assertRaisesRegex(ValueError, "checkpoint tree dtypes or shapes"):
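A minimal helper sketch (not in the PR) that captures the test's branching: with the rolled-back Orbax, a shape mismatch on restore surfaces as a "Cannot intersect index domain" error (likely from the underlying storage layer) rather than the usual "checkpoint tree dtypes or shapes" message. Here builtin_checkpointer_cls stands in for the non-Orbax Checkpointer class used in the test above; the function name is hypothetical.

def expected_restore_mismatch_regex(checkpointer_cls, builtin_checkpointer_cls):
    # Built-in checkpointer: message produced by the shape/dtype compatibility check.
    if checkpointer_cls is builtin_checkpointer_cls:
        return "(checkpoint tree dtypes or shapes|not compatible)"
    # OrbaxCheckpointer on orbax-checkpoint==0.5.23: error raised by the storage layer.
    return "Cannot intersect index domain"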
pyproject.toml (5 additions, 5 deletions)
@@ -24,8 +24,8 @@ core = [
     "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25.
     "einops==0.8.0",
     "importlab==0.7", # breaks pytype on 0.8
-    "jax==0.4.34",
-    "jaxlib==0.4.34",
+    "jax==0.4.33",
+    "jaxlib==0.4.33",
     "nltk==3.7", # for text preprocessing
     "optax==0.1.7", # optimizers (0.1.0 has known bugs).
     "portpicker",
@@ -102,7 +102,7 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.4.34", # must be >=0.4.19 for compat with v5p.
+    "jax[tpu]==0.4.33", # must be >=0.4.19 for compat with v5p.
 ]
 # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
 vertexai_tensorboard = [
@@ -126,7 +126,7 @@ dataflow = [
 # GPU custom kernel dependency.
 gpu = [
     "triton==2.1.0",
-    "jax[cuda12]==0.4.34",
+    "jax[cuda12]==0.4.33",
 ]
 # Open API inference.
 open_api = [
@@ -146,7 +146,7 @@ mmau = [
 # Orbax checkpointing.
 orbax = [
     "humanize==4.10.0",
-    "orbax-checkpoint==0.9.1",
+    "orbax-checkpoint==0.5.23",
 ]
 # Grain input processing. Currently does not support macos.
 grain = [
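A minimal sanity-check sketch (not part of this PR) for confirming that an environment actually picked up the rolled-back pins; the distribution names and versions come from the pyproject.toml diff above.

from importlib.metadata import version

# Expected pins after the rollback (see pyproject.toml diff above).
EXPECTED = {"jax": "0.4.33", "jaxlib": "0.4.33", "orbax-checkpoint": "0.5.23"}

for dist, want in EXPECTED.items():
    got = version(dist)
    status = "OK" if got == want else "MISMATCH"
    print(f"{dist}: installed={got}, expected={want} [{status}]")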