Skip to content

Commit

Permalink
in mem ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Nov 12, 2024
1 parent 14054b4 commit 0e74dc8
Show file tree
Hide file tree
Showing 6 changed files with 984 additions and 22 deletions.
30 changes: 29 additions & 1 deletion axlearn/common/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,8 @@ class Config(Module.Config):
every_n_steps_policy
)

# TODO(hanzhi-zhou): deprecate all checkpoint_paths related class methods in favor of
# checkpoint_steps.
@classmethod
def checkpoint_paths(cls, base_dir: str) -> list[str]:
"""Returns complete checkpoint paths under base dir.
Expand All @@ -744,6 +746,24 @@ def latest_checkpoint_path(cls, base_dir: str) -> str:
# Note: checkpoint_paths should already filter incomplete checkpoints.
return sorted(cls.checkpoint_paths(base_dir)).pop()

@classmethod
def checkpoint_steps(cls, base_dir: str) -> list[int]:
    """Lists the steps of all committed checkpoints found under `base_dir`.

    This is an abstract hook: the base implementation always raises, and concrete
    checkpointers are expected to override it.

    Args:
        base_dir: Path to checkpoints dir.

    Returns:
        A list of committed checkpoint steps. Incomplete checkpoints are dropped.

    Raises:
        NotImplementedError: Always, with `cls` as the exception argument.
    """
    raise NotImplementedError(cls)

@classmethod
def latest_checkpoint_step(cls, base_dir: str) -> int:
    """Returns the most recent (highest step count) checkpoint step under base dir.

    Args:
        base_dir: Path to checkpoints dir.

    Returns:
        The largest step among those reported by `cls.checkpoint_steps(base_dir)`.

    Raises:
        ValueError: If no committed checkpoint exists under `base_dir`.
    """
    # Note: checkpoint_steps should already filter incomplete checkpoints.
    steps = cls.checkpoint_steps(base_dir)
    if not steps:
        # `max([])` would also raise ValueError, but with an opaque
        # "max() arg is an empty sequence" message; name the directory instead.
        raise ValueError(f"No committed checkpoints found under {base_dir!r}.")
    return max(steps)

def __init__(self, cfg: Module.Config, *, parent: Optional[Module]):
super().__init__(cfg, parent=parent)
self._within_context = False
Expand Down Expand Up @@ -850,7 +870,11 @@ class Config(BaseCheckpointer.Config):
@classmethod
def checkpoint_paths(cls, base_dir: str) -> list[str]:
"""See `BaseCheckpointer.checkpointer_paths`."""

logging.log_first_n(
logging.WARNING,
msg="checkpoint_paths is deprecated. Use checkpoint_steps instead.",
n=1,
)
# The default checkpointer commits under "<base_dir>/<step_prefix>_<step>/index". Using a
# concurrent `exists` check for the index file can be several times faster than `glob` on
# gcs when there are many checkpoint files, even if using a "native" solution like
Expand All @@ -867,6 +891,10 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]:
index_exists = pool.map(fs.exists, paths)
return [os.path.dirname(path) for path, committed in zip(paths, index_exists) if committed]

@classmethod
def checkpoint_steps(cls, base_dir: str) -> list[int]:
    """See `BaseCheckpointer.checkpoint_steps`.

    Derives each step by applying `parse_step_from_dir` to every path returned by
    `cls.checkpoint_paths(base_dir)`, preserving that order.
    """
    committed_dirs = cls.checkpoint_paths(base_dir)
    return [parse_step_from_dir(ckpt_dir) for ckpt_dir in committed_dirs]

@classmethod
def cleanup_checkpoint(cls, ckpt_dir: str, *, sync: bool = True):
"""Removes ckpt_dir if it exists.
Expand Down
Loading

0 comments on commit 0e74dc8

Please sign in to comment.