Skip to content

Commit ea8dd38

Browse files
committed
Fix #1595: pass rollout_id explicitly to offload_train
1 parent 7014942 commit ea8dd38

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def train(args):
3535
if args.num_rollout == 0 and args.eval_interval is not None:
3636
ray.get(rollout_manager.eval.remote(rollout_id=0))
3737

38-
def offload_train():
38+
def offload_train(rollout_id):
3939
if args.offload_train:
4040
if args.use_critic:
4141
critic_model.offload()
@@ -82,7 +82,7 @@ def save(rollout_id):
8282
if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout):
8383
save(rollout_id)
8484

85-
offload_train()
85+
offload_train(rollout_id)
8686
if args.offload_rollout:
8787
ray.get(rollout_manager.onload_weights.remote())
8888
actor_model.update_weights()

0 commit comments

Comments
 (0)