diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b84c3a0aa8..244b910786 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -63,6 +63,7 @@ Guidelines for modifications: * Rosario Scalise * Shafeef Omar * Vladimir Fokow +* Wei Yang * Xavier Nal * Yang Jin * Zhengyu Zhang diff --git a/source/standalone/workflows/rsl_rl/train.py b/source/standalone/workflows/rsl_rl/train.py index f02e0a3c0f..6c73798315 100644 --- a/source/standalone/workflows/rsl_rl/train.py +++ b/source/standalone/workflows/rsl_rl/train.py @@ -99,6 +99,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # create isaac environment env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) + + # save resume path before creating a new log_dir + if agent_cfg.resume: + resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) + # wrap for video recording if args_cli.video: video_kwargs = { @@ -122,10 +127,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device) # write git state to logs runner.add_git_repo_to_log(__file__) - # save resume path before creating a new log_dir + # load the checkpoint if agent_cfg.resume: - # get path to previous checkpoint - resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) print(f"[INFO]: Loading model checkpoint from: {resume_path}") # load previously trained model runner.load(resume_path)