Skip to content

Commit 180f460

Browse files
committed
Add colocated mode to agentic cli.
1 parent 60ce6df commit 180f460

8 files changed

Lines changed: 351 additions & 140 deletions

File tree

examples/deepscaler/run_deepscaler_disagg_v5p16.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ python -m tunix.cli.grpo_main \
6464
model_config.remat_config=3 \
6565
actor_model_config.mesh.shape="$trainer_mesh" \
6666
actor_model_config.mesh.axis_names="('fsdp','tp')" \
67-
reference_model_config.mesh=null \
68-
reference_model_config.same_mesh_as="actor" \
6967
rollout_model_config.mesh.shape="$rollout_mesh" \
7068
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
7169
\

examples/deepswe/run_deepswe_disagg_v5p_32.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ python -m tunix.cli.grpo_main \
8181
model_config.remat_config=3 \
8282
actor_model_config.mesh.shape="$trainer_mesh" \
8383
actor_model_config.mesh.axis_names="('fsdp','tp')" \
84-
reference_model_config.mesh=null \
85-
reference_model_config.same_mesh_as="actor" \
8684
rollout_model_config.mesh.shape="$rollout_mesh" \
8785
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
8886
\

examples/rl/grpo/gsm8k/run_qwen3_8b.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ num_generations="${num_generations:-4}"
4545
train_mesh="${train_mesh:-(8,1)}"
4646
rollout_mesh="${rollout_mesh:-(1,8)}"
4747

48+
# Set rollout_colocate to a role name (e.g. "actor") to place the rollout model
49+
# on that role's devices; rollout_mesh may still use a different mesh shape.
50+
rollout_colocate="${rollout_colocate:-null}"
51+
4852
checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
4953
checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}"
5054
if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then
@@ -79,8 +83,7 @@ python -m tunix.cli.grpo_main \
7983
model_config.remat_config=3 \
8084
actor_model_config.mesh.shape="$train_mesh" \
8185
actor_model_config.mesh.axis_names="('fsdp','tp')" \
82-
reference_model_config.mesh=null \
83-
reference_model_config.same_mesh_as="actor" \
86+
rollout_model_config.colocate_with="$rollout_colocate" \
8487
rollout_model_config.mesh.shape="$rollout_mesh" \
8588
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
8689
\

tests/cli/grpo_main_test.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,6 @@ def test_cli_empty_system_prompt_stays_empty_string(self):
644644
)
645645
self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "")
646646

647-
648647
class SplitMeshConfigTest(absltest.TestCase):
649648

650649
def test_split_mesh_uses_explicit_role_meshes(self):
@@ -688,7 +687,6 @@ def test_split_mesh_uses_explicit_role_meshes(self):
688687
"shape": "(2,1)",
689688
"axis_names": "('fsdp','tp')",
690689
}
691-
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
692690
rollout_model_config = pipeline.config["rollout_model_config"]
693691
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
694692
rollout_model_config["mesh"] = {
@@ -732,6 +730,87 @@ def __init__(self, devices, axis_names, axis_types=None):
732730
role_to_mesh[rl_cluster_lib.Role.ACTOR],
733731
)
734732

733+
def test_colocate_with_reuses_device_slice_with_different_mesh(self):
734+
extra = """
735+
training_mode: "agentic_grpo"
736+
data_module: "tunix.cli.recipes.deepscaler_data"
737+
apply_chat_template_to_dataset: false
738+
data_config:
739+
train_data_path: "gs://fake/train.json"
740+
eval_data_path: "gs://fake/eval.parquet"
741+
prompt_key: "prompts"
742+
reward_functions: []
743+
verl_compatible: false
744+
chat_parser_config:
745+
type: "default"
746+
agent_class_path: null
747+
agent_kwargs: {}
748+
env_class_path: null
749+
env_kwargs: {}
750+
kubernetes_config: null
751+
agentic_grpo_config:
752+
num_generations: 2
753+
num_iterations: 1
754+
beta: 0.0
755+
epsilon: 0.2
756+
epsilon_high: 0.28
757+
system_prompt: ""
758+
max_concurrency: 1
759+
off_policy_steps: 0
760+
max_turns: 1
761+
context_ratio: 1
762+
sglang_jax_config:
763+
mem_fraction_static: 0.8
764+
vllm_config:
765+
hbm_utilization: 0.4
766+
"""
767+
pipeline = _make_pipeline(extra)
768+
actor_model_config = pipeline.config["actor_model_config"]
769+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
770+
actor_model_config["mesh"] = {
771+
"shape": "(2,1)",
772+
"axis_names": "('fsdp','tp')",
773+
}
774+
rollout_model_config = pipeline.config["rollout_model_config"]
775+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
776+
rollout_model_config["colocate_with"] = "actor"
777+
rollout_model_config["mesh"] = {
778+
"shape": "(1,2)",
779+
"axis_names": "('fsdp','tp')",
780+
}
781+
782+
fake_devices = list(range(4))
783+
784+
class FakeMesh:
785+
786+
def __init__(self, devices, axis_names, axis_types=None):
787+
self.devices = devices
788+
self.axis_names = axis_names
789+
self.axis_types = axis_types
790+
791+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
792+
with mock.patch.object(
793+
grpo_main.jax.sharding, "Mesh", side_effect=FakeMesh
794+
):
795+
role_to_mesh = pipeline.create_role_to_mesh()
796+
797+
self.assertSequenceEqual(
798+
role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(),
799+
[0, 1],
800+
)
801+
self.assertSequenceEqual(
802+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(),
803+
[0, 1],
804+
)
805+
self.assertEqual(
806+
role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.shape,
807+
(2, 1),
808+
)
809+
self.assertEqual(
810+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape,
811+
(1, 2),
812+
)
813+
735814

736815
if __name__ == "__main__":
737816
absltest.main()

tests/rl/agentic/agentic_grpo_learner_test.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def __init__(self, algo_config):
224224
self.algo_config = algo_config
225225
self.rl_cluster = mock.Mock()
226226
self.metric_fns = []
227+
self._train_micro_batch_size = 1
227228

228229
def _create_micro_batch_iterator(self, iterator, batch_size):
229230
# The dataset batch size is 2, and we want to test micro-batching
@@ -296,11 +297,86 @@ async def _orchestrator_producer(
296297
item = train_data_queue.get(block=True)
297298
if item is None:
298299
break
299-
results.append(item)
300+
results.extend(item)
300301

301302
prompt_ids = [r.prompt_ids[0] for r in results]
302303
self.assertEqual(prompt_ids, [0, 0, 0, 0, 1, 1, 1, 1])
303304

305+
def test_iterator_colocated_batches_full_rollout_batch(self):
306+
class _MockTrainer(agentic_grpo_learner.GRPOLearner):
307+
308+
def __init__(self, algo_config):
309+
self.algo_config = algo_config
310+
self.rl_cluster = mock.Mock()
311+
self.metric_fns = []
312+
self.can_enable_async_rollout = False
313+
self._share_actor_rollout_devices = True
314+
self._full_batch_size = 2
315+
self._train_micro_batch_size = 2
316+
317+
def _create_micro_batch_iterator(self, iterator, batch_size):
318+
del batch_size
319+
for batch in iterator:
320+
for i in range(len(batch["prompts"])):
321+
yield jax.tree.map(lambda x, index=i: x[index : index + 1], batch)
322+
323+
@override
324+
def _batch_to_train_example(self, batch_results, mode):
325+
del mode
326+
examples = []
327+
for _ in range(self.algo_config.num_generations):
328+
examples.append(
329+
types.SimpleNamespace(
330+
prompt_ids=batch_results[1][0]["prompts"],
331+
)
332+
)
333+
return examples
334+
335+
@override
336+
async def _orchestrator_producer(
337+
self,
338+
orchestrator,
339+
prompt_iterator: Iterable[TrainingInputT] | AsyncIterable[TrainingInputT],
340+
num_generations: int = 1,
341+
collect_mode: str = "Token",
342+
):
343+
del orchestrator, num_generations, collect_mode
344+
i = 0
345+
async for example in prompt_iterator:
346+
group = [
347+
types.SimpleNamespace(pair_index=i * 2 + j) for j in range(2)
348+
]
349+
yield group, [example]
350+
i += 1
351+
352+
algo_config = agentic_grpo_learner.GRPOConfig(
353+
num_generations=2,
354+
num_iterations=2,
355+
)
356+
trainer = _MockTrainer(algo_config)
357+
358+
train_data_queue = queue_lib.SimpleDataQueue(maxsize=0)
359+
dataset = _dummy_dataset(MySource(data=[i for i in range(2)]), batch_size=2)
360+
prompt_queue = queue.Queue()
361+
for item in iter(dataset):
362+
prompt_queue.put(item)
363+
prompt_queue.put(None)
364+
365+
asyncio.run(trainer._producer(mock.Mock(), prompt_queue, train_data_queue))
366+
367+
queue_items = []
368+
while True:
369+
item = train_data_queue.get(block=True)
370+
if item is None:
371+
break
372+
queue_items.append(item)
373+
374+
self.assertLen(queue_items, 4)
375+
for batch in queue_items:
376+
self.assertLen(batch, 2)
377+
prompt_ids = [r.prompt_ids[0] for batch in queue_items for r in batch]
378+
self.assertEqual(prompt_ids, [0, 0, 0, 0, 1, 1, 1, 1])
379+
304380
def test_grpo_config_validation(self):
305381
with self.assertRaisesRegex(
306382
ValueError, "num_generations must be greater than 1"
@@ -636,7 +712,7 @@ def mock_compute_rewards(prompts, completions, **kwargs):
636712
algo_config=grpo_config,
637713
chat_parser=MockChatParser(),
638714
)
639-
715+
640716
with mock.patch.object(learner, "_compute_rewards", side_effect=mock_compute_rewards):
641717
with mock.patch.object(
642718
learner.rl_cluster,
@@ -645,7 +721,7 @@ def mock_compute_rewards(prompts, completions, **kwargs):
645721
autospec=True,
646722
):
647723
learner._process_results(trajectories)
648-
724+
649725
self.assertEqual(extracted_completions, ["msg 0", "msg 1"])
650726

651727
@parameterized.named_parameters(

tunix/cli/grpo_main.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class GrpoPipeline(config.HyperParameters):
7373
plus ``max_turns``, ``context_ratio``, ``per_turn_timeout_secs``.
7474
* role-specific ``*_model_config.mesh``: any role with an explicit mesh gets
7575
its own device slice; omitted meshes share the actor mesh by default.
76-
* role-specific ``same_mesh_as``: optional mesh sharing like
77-
``reference_model_config.same_mesh_as: actor``.
76+
* role-specific ``colocate_with``: share another role's device set while
77+
still allowing a different mesh shape on that same device set.
7878
* ``sglang_jax_config`` / ``vllm_config``: engine-specific rollout params.
7979
* ``chat_parser_config.type``: ``"default"`` or ``"qwen"``.
8080
* ``agent_class_path`` / ``env_class_path``: dotted Python paths to load
@@ -116,21 +116,19 @@ def _resolve_split_role(self, role_name: str) -> rl_cluster_lib.Role:
116116
)
117117
return self._SPLIT_ROLE_ALIASES[normalized]
118118

119-
def _get_same_mesh_as_map(
119+
def _get_colocate_with_map(
120120
self,
121121
) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]:
122-
same_mesh_as = {}
122+
colocate_with = {}
123123
for role, model_key in self._ROLE_TO_MODEL_KEY.items():
124124
model_cfg = self.config.get(model_key, {}) or {}
125-
target_name = model_cfg.get("same_mesh_as")
125+
target_name = model_cfg.get("colocate_with")
126126
if target_name is None:
127127
continue
128-
target_role = self._resolve_split_role(str(target_name))
129128
if role == rl_cluster_lib.Role.ACTOR:
130-
raise ValueError("Actor must own its mesh.")
131-
same_mesh_as[role] = target_role
132-
133-
return same_mesh_as
129+
raise ValueError("Actor must own its device set.")
130+
colocate_with[role] = self._resolve_split_role(str(target_name))
131+
return colocate_with
134132

135133
def _is_role_active(self, role: rl_cluster_lib.Role) -> bool:
136134
if role in (
@@ -145,10 +143,10 @@ def _is_role_active(self, role: rl_cluster_lib.Role) -> bool:
145143
def _resolve_mesh_owners(
146144
self,
147145
) -> dict[rl_cluster_lib.Role, rl_cluster_lib.Role]:
148-
same_mesh_as = self._get_same_mesh_as_map()
146+
colocate_with = self._get_colocate_with_map()
149147
base_owners = {}
150148
for role, model_key in self._ROLE_TO_MODEL_KEY.items():
151-
if not self._is_role_active(role) and role not in same_mesh_as:
149+
if not self._is_role_active(role):
152150
continue
153151
has_mesh = bool(self.config.get(model_key, {}).get("mesh"))
154152
base_owners[role] = (
@@ -162,35 +160,28 @@ def resolve_owner(
162160
seen: set[rl_cluster_lib.Role],
163161
) -> rl_cluster_lib.Role:
164162
if role in seen:
165-
raise ValueError("same_mesh_as contains a cycle.")
166-
if role not in same_mesh_as:
163+
raise ValueError("colocate_with contains a cycle.")
164+
if role not in colocate_with:
167165
return base_owners[role]
168166
seen.add(role)
169-
target_role = same_mesh_as[role]
167+
target_role = colocate_with[role]
170168
if target_role not in base_owners:
171169
raise ValueError(
172170
f"Role {target_role.value!r} is not active in this config."
173171
)
174172
return resolve_owner(target_role, seen)
175173

176174
role_to_owner = {}
177-
for role, model_key in self._ROLE_TO_MODEL_KEY.items():
178-
if role not in base_owners:
179-
continue
180-
has_mesh = bool(self.config.get(model_key, {}).get("mesh"))
181-
if role in same_mesh_as:
182-
if has_mesh:
183-
raise ValueError(
184-
f"{model_key}.mesh is specified, so it must own a separate mesh "
185-
"and cannot also use same_mesh_as."
186-
)
187-
else:
188-
role_to_owner[role] = resolve_owner(role, set())
189-
continue
175+
for role in base_owners:
190176
role_to_owner[role] = resolve_owner(role, set())
191177
return role_to_owner
192178

193-
def _create_role_to_mesh(self):
179+
def create_role_to_mesh(self):
180+
"""Build role→mesh mapping.
181+
182+
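Any role with an explicit ``*.mesh`` config gets a dedicated device slice,
183+
unless it sets ``colocate_with``: it then reuses the target role's device slice and builds its own mesh (if configured) on those devices. Roles without a mesh share the actor mesh by default.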
184+
"""
194185
devices = list(jax.devices())
195186
role_to_owner = self._resolve_mesh_owners()
196187
owner_order = []
@@ -235,16 +226,18 @@ def _create_role_to_mesh(self):
235226
for owner in owner_order
236227
},
237228
)
238-
return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()}
239-
240-
def create_role_to_mesh(self):
241-
"""Build role→mesh mapping.
229+
role_to_mesh = {}
230+
for role, owner in role_to_owner.items():
231+
model_key = self._ROLE_TO_MODEL_KEY[role]
232+
has_mesh = bool(self.config.get(model_key, {}).get("mesh"))
233+
if role == owner or not has_mesh:
234+
role_to_mesh[role] = owner_to_mesh[owner]
235+
else:
236+
role_to_mesh[role] = self.create_mesh(
237+
model_key, devices=owner_to_device_slice[owner]
238+
)
239+
return role_to_mesh
242240

243-
Any role with an explicit ``*.mesh`` config gets a dedicated device slice.
244-
Roles without a mesh share the actor mesh by default, or can point at
245-
another role via ``same_mesh_as``.
246-
"""
247-
return self._create_role_to_mesh()
248241

249242
# ------------------------------------------------------------------
250243
# Rollout config

0 commit comments

Comments
 (0)