
Commit 4f9df7b

Add colocated mode to agentic cli.
1 parent 37449fe commit 4f9df7b

14 files changed

Lines changed: 1133 additions & 291 deletions

docs/agentic_rl.md

Lines changed: 47 additions & 0 deletions
@@ -120,6 +120,53 @@ generating trajectories with stale parameters.
   <img src="images/batch_vs_async_rollout.png" alt="Batch vs Async Rollout" width="50%">
 </p>
 
+### Mesh Placement and Colocation
+
+Agentic GRPO supports three distinct placement patterns for the actor,
+reference, and rollout roles:
+
+1. **Shared mesh**: multiple roles reuse the exact same mesh object. This is
+   the most tightly colocated setup and may enable model or backbone sharing
+   in parts of the stack.
+
+   Backend status: exact shared-mesh execution is currently supported only for
+   the vanilla rollout backend; it is not supported yet for `vllm` and
+   `sglang_jax`.
+
+2. **Colocated device set**: a role uses `colocate_with` to reuse another
+   role's device slice while still keeping its own mesh shape. For example,
+   the actor can use `(4, 1)` while rollout uses `(1, 4)` on the same four
+   devices. This is still colocated, but it is different from exact mesh
+   sharing.
+
+3. **Disaggregated placement**: each role owns a separate device slice.
+
+For CLI-driven GRPO and agentic GRPO runs, `colocate_with` is configured on the
+role-specific model config, for example:
+
+```yaml
+actor_model_config:
+  mesh:
+    shape: "(4,1)"
+    axis_names: "('fsdp','tp')"
+
+rollout_model_config:
+  colocate_with: "actor"
+  mesh:
+    shape: "(1,4)"
+    axis_names: "('fsdp','tp')"
+```
+
+This tells Tunix to allocate only one four-device owner slice and build two
+different meshes on top of it. Exact mesh equality is only required when a
+runtime path wants to reuse the same model instance; it is not required for
+colocation itself.
+
+For accelerated rollout backends, treat same-device-set colocation and exact
+shared-mesh reuse as different features. Today, `vllm` and `sglang_jax`
+support colocated device-set placement through `colocate_with`, but exact
+shared-mesh execution is not supported yet.
+
 ### Trajectory Batching and Grouping
 
 Tunix supports batching of agentic trajectories through the `GroupQueueManager`.
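
To make the distinction concrete, here is a minimal JAX sketch (illustrative only; it is not part of the diff) of what the YAML above resolves to: two differently shaped meshes built over one shared four-device slice.

```python
import numpy as np
import jax

# Assumption: the first four local devices form the shared owner slice.
devices = np.asarray(jax.devices()[:4])

# Actor mesh uses the (4, 1) layout from `actor_model_config.mesh`.
actor_mesh = jax.sharding.Mesh(devices.reshape(4, 1), ("fsdp", "tp"))
# Rollout reuses the same devices with the (1, 4) layout from its own config.
rollout_mesh = jax.sharding.Mesh(devices.reshape(1, 4), ("fsdp", "tp"))

# Colocated: both meshes cover the same device set ...
assert set(actor_mesh.devices.flat) == set(rollout_mesh.devices.flat)
# ... but this is not exact mesh sharing: layouts (and mesh objects) differ.
assert actor_mesh.devices.shape != rollout_mesh.devices.shape
```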

docs/launching.md

Lines changed: 12 additions & 2 deletions
@@ -254,7 +254,7 @@ This section provides a detailed explanation of the configuration parameters ava
 
 #### Model Configuration (`model_config`)
 
-These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration.
+These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration.
 
 * **`model_name`**: The unique full name identifier of the model. This
   corresponds to the full name and should match exactly with the model name
@@ -287,6 +287,16 @@ These parameters define the base model, where to download it from, and how to sh
 * **`mesh`**: Defines the hardware mesh layout for distributed training.
   * `shape`: Tuple string defining mesh dimensions (e.g., `"(2,2)"` for a 2x2 grid).
   * `axis_names`: Names for mesh axes, often used for parallelism strategies (e.g., `"('fsdp','tp')"` for Fully Sharded Data Parallelism and Tensor Parallelism).
+* **`colocate_with`**: Optional role-local placement override for
+  `actor_model_config`, `reference_model_config`, `rollout_model_config`, and
+  other RL roles.
+  * If unset, a role owns its own device slice when it has an explicit
+    `mesh`, or shares the actor mesh by default when it does not.
+  * If set to a role name such as `"actor"`, the role reuses that role's
+    device slice but may still define its own `mesh.shape` and
+    `mesh.axis_names`.
+  * This is different from exact mesh sharing: two roles can be colocated on
+    the same devices while using different mesh layouts.
 
 
 #### Tokenizer Configuration (`tokenizer_config`)
@@ -338,7 +348,7 @@ General settings for the training loop, logging, and checkpointing.
 
 * **`eval_every_n_steps`**: Frequency of running evaluation steps.
 
-* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients
+* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients
   before performing a parameter update (simulates larger batch sizes).
 
 * **`checkpointing_options`**:
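
The `colocate_with` resolution rules documented above can be restated as a small sketch. This is an illustration of the documented behavior only; the helper name and the dict-shaped config are hypothetical, not the Tunix implementation.

```python
# Sketch of the documented `colocate_with` rules (hypothetical helper).
def resolve_device_owner(role_name: str, role_config: dict) -> str:
  """Returns the role whose device slice `role_name` runs on."""
  colocate_with = role_config.get("colocate_with") or None  # "" counts as unset
  if colocate_with:              # explicit colocation: reuse that role's slice
    return colocate_with
  if role_config.get("mesh"):    # explicit mesh, no colocate_with: own slice
    return role_name
  return "actor"                 # default: share the actor's devices


# Example: rollout colocated with the actor while keeping its own mesh shape.
rollout_cfg = {"colocate_with": "actor", "mesh": {"shape": "(1,4)"}}
assert resolve_device_owner("rollout", rollout_cfg) == "actor"
```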

docs/performance.md

Lines changed: 36 additions & 2 deletions
@@ -156,8 +156,13 @@ and training. To further maximize the hardware utility, you can consider enablin
 non-active models to CPU RAM when a
 different component is occupying the TPU.
 
-Enabling collocated mode is straightforward; you simply provide the same mesh to
-every component when configuring the `role_to_mesh` mapping for your `rl_cluster`.
+Enabling collocated mode is straightforward; the strongest form is to provide
+the same mesh to every component when configuring the `role_to_mesh` mapping
+for your `rl_cluster`.
+
+Backend status: exact shared-mesh execution is currently supported only for
+the vanilla rollout backend; it is not supported yet for `vllm` and
+`sglang-jax`.
 
 ```python
 import numpy as np
@@ -179,6 +184,35 @@ ClusterConfig(
 )
 ```
 
+For CLI-driven GRPO and agentic GRPO runs, there is now a second colocated
+mode: **same device set, different mesh shape**. This is configured with
+`colocate_with`.
+
+```yaml
+actor_model_config:
+  mesh:
+    shape: "(4,1)"
+    axis_names: "('fsdp','tp')"
+
+rollout_model_config:
+  colocate_with: "actor"
+  mesh:
+    shape: "(1,4)"
+    axis_names: "('fsdp','tp')"
+```
+
+In this configuration, actor and rollout are still colocated because they run
+on the same device slice, but they do not share the exact same mesh object.
+That distinction matters:
+
+* Same device set means the roles are colocated.
+* Same mesh may additionally allow model or backbone sharing in some runtime
+  paths.
+
+For `vllm` and `sglang-jax`, this same-device-set colocation mode is the
+currently supported colocated placement. Exact shared-mesh reuse is not
+supported yet for those backends.
+
 ### Disaggregated Execution
 
 Disaggregated mode partitions the TPU cluster into distinct "sub-meshes",
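
For comparison, a minimal sketch of the two colocated modes discussed above (illustrative only; the string keys stand in for the role identifiers used by the real `role_to_mesh` mapping, which is truncated in the diff above).

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices()[:4])

# Strongest (exact shared-mesh) form: every role reuses the very same Mesh
# object. Today this is only supported with the vanilla rollout backend.
shared_mesh = Mesh(devices.reshape(4, 1), ("fsdp", "tp"))
role_to_mesh_shared = {"actor": shared_mesh, "reference": shared_mesh,
                       "rollout": shared_mesh}

# Same-device-set form: rollout runs on the actor's devices but with its own
# independently shaped mesh. This is the colocated placement currently
# available for the vllm and sglang-jax backends.
role_to_mesh_colocated = {
    "actor": Mesh(devices.reshape(4, 1), ("fsdp", "tp")),
    "rollout": Mesh(devices.reshape(1, 4), ("fsdp", "tp")),
}
```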

docs/rollout.md

Lines changed: 10 additions & 0 deletions
@@ -91,6 +91,11 @@ Setting `cluster_config.rollout_engine="vllm"` enables the vllm rollout/sampler.
 Tunix uses `tunix.rl.rollout.base_rollout.RolloutConfig` for rollout settings.
 The fields below are the vLLM-relevant ones.
 
+Exact shared-mesh execution is currently supported only for the vanilla rollout
+backend. For `vllm`, exact shared-mesh execution is not supported yet. The
+supported colocated configuration today is same-device-set placement with an
+independently shaped rollout mesh.
+
 #### vLLM-specific fields
 
 In addition to the common sampling parameters mentioned above, the following
@@ -277,6 +282,11 @@ Tunix uses `tunix.rl.rollout.base_rollout.RolloutConfig` for rollout settings.
 In addition to the common sampling parameters, the following fields are specific
 to SGLang-Jax:
 
+Exact shared-mesh execution is currently supported only for the vanilla rollout
+backend. For `sglang_jax`, exact shared-mesh execution is not supported yet.
+The supported colocated configuration today is same-device-set placement with
+an independently shaped rollout mesh.
+
 - `rollout_sglang_jax_model_version`
 
   - Model id or local path used by SGLang-Jax as `model_path`.

examples/deepscaler/run_deepscaler_disagg_v5p16.sh

Lines changed: 0 additions & 2 deletions
@@ -64,8 +64,6 @@ python -m tunix.cli.grpo_main \
   model_config.remat_config=3 \
   actor_model_config.mesh.shape="$trainer_mesh" \
   actor_model_config.mesh.axis_names="('fsdp','tp')" \
-  reference_model_config.mesh=null \
-  reference_model_config.same_mesh_as="actor" \
   rollout_model_config.mesh.shape="$rollout_mesh" \
   rollout_model_config.mesh.axis_names="('fsdp','tp')" \
   \

examples/deepswe/run_deepswe_disagg_v5p_32.sh

Lines changed: 0 additions & 2 deletions
@@ -81,8 +81,6 @@ python -m tunix.cli.grpo_main \
   model_config.remat_config=3 \
   actor_model_config.mesh.shape="$trainer_mesh" \
   actor_model_config.mesh.axis_names="('fsdp','tp')" \
-  reference_model_config.mesh=null \
-  reference_model_config.same_mesh_as="actor" \
  rollout_model_config.mesh.shape="$rollout_mesh" \
   rollout_model_config.mesh.axis_names="('fsdp','tp')" \
   \

examples/rl/grpo/gsm8k/run_qwen3_8b.sh

Lines changed: 6 additions & 3 deletions
@@ -45,7 +45,11 @@ num_generations="${num_generations:-4}"
 train_mesh="${train_mesh:-(8,1)}"
 rollout_mesh="${rollout_mesh:-(1,8)}"
 
-checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
+# Set rollout_colocate to a role name (e.g. "actor") to colocate the rollout
+# model on the same devices as the actor model.
+rollout_colocate="${rollout_colocate:-null}"
+
+checkpoint_dir="${checkpoint_dir-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
 checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}"
 if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then
   checkpoint_dir="${checkpoint_dir}_${checkpoint_suffix}"
@@ -79,8 +83,7 @@ python -m tunix.cli.grpo_main \
   model_config.remat_config=3 \
   actor_model_config.mesh.shape="$train_mesh" \
   actor_model_config.mesh.axis_names="('fsdp','tp')" \
-  reference_model_config.mesh=null \
-  reference_model_config.same_mesh_as="actor" \
+  rollout_model_config.colocate_with="$rollout_colocate" \
   rollout_model_config.mesh.shape="$rollout_mesh" \
   rollout_model_config.mesh.axis_names="('fsdp','tp')" \
   \

tests/cli/grpo_main_test.py

Lines changed: 122 additions & 2 deletions
@@ -644,7 +644,6 @@ def test_cli_empty_system_prompt_stays_empty_string(self):
     )
     self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "")
 
-
 class SplitMeshConfigTest(absltest.TestCase):
 
   def test_split_mesh_uses_explicit_role_meshes(self):
@@ -688,7 +687,6 @@ def test_split_mesh_uses_explicit_role_meshes(self):
           "shape": "(2,1)",
           "axis_names": "('fsdp','tp')",
       }
-    pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
     rollout_model_config = pipeline.config["rollout_model_config"]
     if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
       rollout_model_config["mesh"] = {
@@ -732,6 +730,128 @@ def __init__(self, devices, axis_names, axis_types=None):
         role_to_mesh[rl_cluster_lib.Role.ACTOR],
     )
 
+  def test_colocate_with_reuses_device_slice_with_different_mesh(self):
+    extra = """
+training_mode: "agentic_grpo"
+data_module: "tunix.cli.recipes.deepscaler_data"
+apply_chat_template_to_dataset: false
+data_config:
+  train_data_path: "gs://fake/train.json"
+  eval_data_path: "gs://fake/eval.parquet"
+  prompt_key: "prompts"
+reward_functions: []
+verl_compatible: false
+chat_parser_config:
+  type: "default"
+agent_class_path: null
+agent_kwargs: {}
+env_class_path: null
+env_kwargs: {}
+kubernetes_config: null
+agentic_grpo_config:
+  num_generations: 2
+  num_iterations: 1
+  beta: 0.0
+  epsilon: 0.2
+  epsilon_high: 0.28
+  system_prompt: ""
+  max_concurrency: 1
+  off_policy_steps: 0
+  max_turns: 1
+  context_ratio: 1
+sglang_jax_config:
+  mem_fraction_static: 0.8
+vllm_config:
+  hbm_utilization: 0.4
+"""
+    pipeline = _make_pipeline(extra)
+    actor_model_config = pipeline.config["actor_model_config"]
+    if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
+      actor_model_config["mesh"] = {
+          "shape": "(2,1)",
+          "axis_names": "('fsdp','tp')",
+      }
+    rollout_model_config = pipeline.config["rollout_model_config"]
+    if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
+      rollout_model_config["colocate_with"] = "actor"
+      rollout_model_config["mesh"] = {
+          "shape": "(1,2)",
+          "axis_names": "('fsdp','tp')",
+      }
+
+    fake_devices = list(range(4))
+
+    class FakeMesh:
+
+      def __init__(self, devices, axis_names, axis_types=None):
+        self.devices = devices
+        self.axis_names = axis_names
+        self.axis_types = axis_types
+
+    with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
+      with mock.patch.object(
+          grpo_main.jax.sharding, "Mesh", side_effect=FakeMesh
+      ):
+        role_to_mesh = pipeline.create_role_to_mesh()
+
+    self.assertSequenceEqual(
+        role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(),
+        [0, 1],
+    )
+    self.assertSequenceEqual(
+        role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(),
+        [0, 1],
+    )
+    self.assertEqual(
+        role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.shape,
+        (2, 1),
+    )
+    self.assertEqual(
+        role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape,
+        (1, 2),
+    )
+
+  def test_empty_string_colocate_with_is_treated_as_unset(self):
+    extra = """
+training_mode: "agentic_grpo"
+data_module: "tunix.cli.recipes.deepscaler_data"
+apply_chat_template_to_dataset: false
+data_config:
+  train_data_path: "gs://fake/train.json"
+  eval_data_path: "gs://fake/eval.parquet"
+  prompt_key: "prompts"
+reward_functions: []
+verl_compatible: false
+chat_parser_config:
+  type: "default"
+agent_class_path: null
+agent_kwargs: {}
+env_class_path: null
+env_kwargs: {}
+kubernetes_config: null
+agentic_grpo_config:
+  num_generations: 2
+  num_iterations: 1
+  beta: 0.0
+  epsilon: 0.2
+  epsilon_high: 0.28
+  system_prompt: ""
+  max_concurrency: 1
+  off_policy_steps: 0
+  max_turns: 1
+  context_ratio: 1
+sglang_jax_config:
+  mem_fraction_static: 0.8
+vllm_config:
+  hbm_utilization: 0.4
+"""
+    pipeline = _make_pipeline(extra)
+    rollout_model_config = pipeline.config["rollout_model_config"]
+    if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
+      rollout_model_config["colocate_with"] = ""
+
+    self.assertEmpty(pipeline._get_colocate_with_map())
+
 
 if __name__ == "__main__":
   absltest.main()
