Skip to content

Commit 2f12bb9

Browse files
ooctipuskellyguo11
andauthored
Enables sb3 to load checkpoint to continue training (isaac-sim#2954)
# Description This PR extend `script/reinforcement_learning/sb3/train.py` with feature to continue learning by loading the checkpoint. ## Type of change <!-- As you go through the list, delete the ones that are not applicable. --> - New feature (non-breaking change which adds functionality) ## Screenshots Please attach before and after screenshots of the change if applicable. <!-- Example: | Before | After | | ------ | ----- | | _gif/png before_ | _gif/png after_ | To upload images to a PR -- simply drag and drop an image while in edit mode and it should upload the image directly. You can then paste that source into the above before/after sections. --> ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there <!-- As you go through the checklist above, you can mark something as done by putting an x character in it For example, - [x] I have done this task - [ ] I have not done this task --> --------- Co-authored-by: Kelly Guo <kellyg@nvidia.com> Co-authored-by: Kelly Guo <kellyguo123@hotmail.com>
1 parent ed8fe3c commit 2f12bb9

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

  • scripts/reinforcement_learning/sb3

scripts/reinforcement_learning/sb3/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
2626
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
2727
parser.add_argument("--log_interval", type=int, default=100_000, help="Log data every n timesteps.")
28+
parser.add_argument("--checkpoint", type=str, default=None, help="Continue the training from checkpoint.")
2829
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
2930
parser.add_argument(
3031
"--keep_all_info",
@@ -179,6 +180,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
179180

180181
# create agent from stable baselines
181182
agent = PPO(policy_arch, env, verbose=1, tensorboard_log=log_dir, **agent_cfg)
183+
if args_cli.checkpoint is not None:
184+
agent = agent.load(args_cli.checkpoint, env, print_system_info=True)
182185

183186
# callbacks for agent
184187
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2)

0 commit comments

Comments
 (0)