From f879aa6a80e2b9e6dba1263e8f5938ccbaa65efc Mon Sep 17 00:00:00 2001 From: Wei Yang Date: Sun, 13 Oct 2024 03:40:10 -0700 Subject: [PATCH] Fixes the checkpoint loading error in RSL-RL training script (#1210) # Description A `No checkpoints in the directory` error is thrown when resuming from a previous training run with `--video` set. This is because a new log folder is created before the checkpoint lookup. This MR fixes the issue by resolving the checkpoint path before the new log folder is created. Fixes #1209 ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there --- CONTRIBUTORS.md | 1 + source/standalone/workflows/rsl_rl/train.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b84c3a0aa8..244b910786 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -63,6 +63,7 @@ Guidelines for modifications: * Rosario Scalise * Shafeef Omar * Vladimir Fokow +* Wei Yang * Xavier Nal * Yang Jin * Zhengyu Zhang diff --git a/source/standalone/workflows/rsl_rl/train.py b/source/standalone/workflows/rsl_rl/train.py index f02e0a3c0f..6c73798315 100644 --- a/source/standalone/workflows/rsl_rl/train.py +++ b/source/standalone/workflows/rsl_rl/train.py @@ -99,6 +99,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # create isaac environment env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) + + # save resume path before 
creating a new log_dir + if agent_cfg.resume: + resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) + # wrap for video recording if args_cli.video: video_kwargs = { @@ -122,10 +127,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device) # write git state to logs runner.add_git_repo_to_log(__file__) - # save resume path before creating a new log_dir + # load the checkpoint if agent_cfg.resume: - # get path to previous checkpoint - resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) print(f"[INFO]: Loading model checkpoint from: {resume_path}") # load previously trained model runner.load(resume_path)