Skip to content

Commit a93fc65

Browse files
hgao327 and The tunix Authors
authored and committed
Allow configurable micro-batching for compute_logps in agentic RL.
PiperOrigin-RevId: 914005102
1 parent 388c8a2 commit a93fc65

2 files changed

Lines changed: 53 additions & 14 deletions

File tree

tests/rl/agentic/agentic_rl_learner_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@
1414

1515
"""Tests for agentic_rl_learner."""
1616

17+
import asyncio
18+
from typing import Any
1719
from unittest import mock
1820

21+
from absl import logging
1922
from absl.testing import absltest
2023
from absl.testing import parameterized
24+
from tunix.rl import rl_cluster as rl_cluster_lib
25+
from tunix.rl import utils as rl_utils
2126
from tunix.rl.agentic import agentic_rl_learner
2227
from tunix.rl.rollout import base_rollout
2328

@@ -135,6 +140,44 @@ def test_validate_rollout_config_vllm_missing_server_mode(self):
135140
algo_config=algo_config,
136141
)
137142

143+
def test_train_batch_size_mismatch_raises_error(self):
144+
with mock.patch.object(
145+
rl_utils, "is_sharing_weights", return_value=False
146+
):
147+
rl_cluster = mock.Mock()
148+
rl_cluster.cluster_config = mock.Mock()
149+
rl_cluster.cluster_config.role_to_mesh = {
150+
rl_cluster_lib.Role.ACTOR: mock.Mock(),
151+
rl_cluster_lib.Role.ROLLOUT: mock.Mock(),
152+
}
153+
training_config = mock.Mock()
154+
training_config.compute_logps_micro_batch_size = 2
155+
training_config.train_micro_batch_size = 1
156+
training_config.mini_batch_size = None
157+
rl_cluster.cluster_config.training_config = training_config
158+
rl_cluster.cluster_config.rollout_config = base_rollout.RolloutConfig(
159+
max_tokens_to_generate=10, return_logprobs=True
160+
)
161+
rl_cluster.cluster_config.rollout_engine = 'generic'
162+
rl_cluster.actor_trainer = mock.Mock()
163+
rl_cluster.actor_trainer.restored_global_step.return_value = 0
164+
rl_cluster.actor_trainer.iter_steps = 0
165+
rl_cluster.rollout = mock.Mock()
166+
rl_cluster.tokenizer = mock.Mock()
167+
algo_config = agentic_rl_learner.AgenticRLConfig(max_response_length=10)
168+
learner = DummyLearner(
169+
rl_cluster=rl_cluster,
170+
reward_fns=mock.Mock(),
171+
algo_config=algo_config,
172+
)
173+
train_dataset = [{'prompt': ['p1']}]
174+
with self.assertRaisesRegex(
175+
ValueError,
176+
r'compute_logps_micro_batch_size \(2\) must be equal to'
177+
r' train_micro_batch_size \(1\)',
178+
):
179+
learner.train(train_dataset)
180+
138181

139182
if __name__ == "__main__":
140183
absltest.main()

tunix/rl/agentic/agentic_rl_learner.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __init__(
211211
self._training_config.rollout_micro_batch_size
212212
)
213213
self._compute_logps_micro_batch_size = (
214-
self._training_config.compute_logps_micro_batch_size
214+
self._training_config.compute_logps_micro_batch_size or 1
215215
)
216216
sft_utils.show_hbm_usage(title="AgenticRLLearner init")
217217

@@ -413,8 +413,6 @@ def _model_call(
413413
if "pair_index" in env.extra_kwargs:
414414
tags[perf_constants.PAIR_INDEX] = env.extra_kwargs["pair_index"]
415415

416-
417-
418416
result = self.rl_cluster.generate(
419417
prompts=chat_lists,
420418
apply_chat_template=False if self.chat_parser else True,
@@ -694,22 +692,20 @@ def train(
694692
train_micro_batch_size = (
695693
self._training_config.train_micro_batch_size or mini_batch_size
696694
)
697-
# Rollout and compute_logps micro batch sizes have to be 1 since we only
698-
# process inidividual prompts.
695+
# Rollout micro batch size has to be 1 since we only process individual
696+
# prompts.
699697
self._rollout_micro_batch_size = 1
700-
701-
compute_logps_mb = self._training_config.compute_logps_micro_batch_size
702698
self._process_in_consumer = False
703-
if compute_logps_mb is not None and compute_logps_mb > 1:
704-
if compute_logps_mb != train_micro_batch_size:
699+
700+
if self._compute_logps_micro_batch_size > 1:
701+
if self._compute_logps_micro_batch_size != train_micro_batch_size:
705702
raise ValueError(
706-
f"compute_logps_micro_batch_size ({compute_logps_mb}) must be"
707-
f" equal to train_micro_batch_size ({train_micro_batch_size})"
703+
"compute_logps_micro_batch_size"
704+
f" ({self._compute_logps_micro_batch_size}) must be equal to"
705+
f" train_micro_batch_size ({train_micro_batch_size})"
708706
)
709707
self._process_in_consumer = True
710-
self._compute_logps_micro_batch_size = compute_logps_mb
711-
else:
712-
self._compute_logps_micro_batch_size = 1
708+
713709
for v, n in [
714710
(self._rollout_micro_batch_size, f"{self._rollout_micro_batch_size=}"),
715711
(

0 commit comments

Comments (0)