Commit e8b6ae7

[skyrl-train] Fix num_training_steps for workers being set incorrectly (#873)
`num_training_steps` was being set to the number of training batches rather than the number of optimizer (mini-batch) steps, causing learning rate decay to progress too quickly when using a non-constant learning rate scheduler. Renames to `num_training_batches` for clarity, since each training batch can contain several optimizer steps (see the worked example below). Closes #872
1 parent 4cf9419 commit e8b6ae7
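
A worked example of the corrected step accounting. The names mirror the diff below, but the numbers are hypothetical, chosen only for illustration:

```python
# Hypothetical config values; skyrl-train reads these from cfg.trainer.
train_batch_size = 1024        # prompts per training batch
policy_mini_batch_size = 256   # prompts per optimizer (mini-batch) step
update_epochs_per_batch = 2    # passes over each training batch

# Each training batch yields several optimizer steps:
policy_steps_per_train_batch = train_batch_size // policy_mini_batch_size * update_epochs_per_batch
assert policy_steps_per_train_batch == 8

# The scheduler must be sized in optimizer steps, not training batches:
total_training_batches = 100
num_training_steps = total_training_batches * policy_steps_per_train_batch
assert num_training_steps == 800
```

Before this fix the workers received the batch count (100 here) as `num_training_steps`, so a decaying scheduler finished its decay 8x too early.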

File tree

1 file changed: +12 -4 lines changed

skyrl-train/skyrl_train/trainer.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -497,21 +497,29 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
         else:
             critic_model = None
 
+        policy_steps_per_train_batch = (
+            cfg.trainer.train_batch_size // cfg.trainer.policy_mini_batch_size * cfg.trainer.update_epochs_per_batch
+        )
+        critic_steps_per_train_batch = 0
+        if cfg.trainer.critic.model.path:
+            critic_steps_per_train_batch = (
+                cfg.trainer.train_batch_size // cfg.trainer.critic_mini_batch_size * cfg.trainer.update_epochs_per_batch
+            )
         if not cfg.trainer.placement.colocate_all:
             refs = []
             if ref_model is not None:
                 refs.extend(ref_model.async_init_model(cfg.trainer.ref.model.path))
             refs.extend(
                 policy_model.async_init_model(
                     cfg.trainer.policy.model.path,
-                    num_training_steps=self.total_training_steps,
+                    num_training_steps=self.total_training_steps * policy_steps_per_train_batch,
                 )
             )
             if cfg.trainer.critic.model.path:
                 refs.extend(
                     critic_model.async_init_model(
                         cfg.trainer.critic.model.path,
-                        num_training_steps=self.total_training_steps,
+                        num_training_steps=self.total_training_steps * critic_steps_per_train_batch,
                     )
                 )
             ray.get(refs)
@@ -523,7 +531,7 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
             ray.get(
                 policy_model.async_init_model(
                     cfg.trainer.policy.model.path,
-                    num_training_steps=self.total_training_steps,
+                    num_training_steps=self.total_training_steps * policy_steps_per_train_batch,
                 )
             )
             ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self.tokenizer.pad_token_id))
@@ -532,7 +540,7 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker):
             ray.get(
                 critic_model.async_init_model(
                     cfg.trainer.critic.model.path,
-                    num_training_steps=self.total_training_steps,
+                    num_training_steps=self.total_training_steps * critic_steps_per_train_batch,
                 )
             )
             critic_model.offload_to_cpu()
```
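
For context on why the multiplier matters: a scheduler decays over a declared number of optimizer steps, so under-counting compresses the decay. A minimal sketch with a linear-decay schedule, assuming hypothetical model and optimizer objects (the workers' actual scheduler may differ):

```python
import torch

# Hypothetical model/optimizer, purely for illustration.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

total_training_batches = 100
policy_steps_per_train_batch = 8  # from the worked example above
num_training_steps = total_training_batches * policy_steps_per_train_batch  # 800

# Linear decay to zero over the declared number of optimizer steps.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: max(0.0, 1.0 - step / num_training_steps)
)

# scheduler.step() runs once per optimizer (mini-batch) step. Had
# num_training_steps been set to the batch count (100) instead of the
# optimizer-step count (800), the LR would hit zero after the first
# 100 of the 800 steps, which is the bug this commit fixes.
```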
