We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent df627a6 commit 0eb6763Copy full SHA for 0eb6763
tunix/rl/experimental/agentic_rl_learner.py
@@ -511,10 +511,8 @@ def _batch_to_train_example(
511
"""
512
# Create a merged training_input where each field from the original input
513
# is repeated G times to align with the G completions.
514
- num_generations = self.algo_config.num_generations
515
- prompt_index = batch_results[0].pair_index // num_generations
516
- if mode == rl_cluster_lib.Mode.TRAIN and self._full_batch_size:
517
- expected_step = prompt_index // self._full_batch_size
+ if mode == rl_cluster_lib.Mode.TRAIN:
+ expected_step = batch_results[0].group_id // self._full_batch_size
518
else:
519
expected_step = self.rl_cluster.global_steps
520
0 commit comments