Skip to content

Commit 227476b

Browse files
committed
Add colocated mode to agentic cli.
1 parent 37449fe commit 227476b

10 files changed

Lines changed: 1028 additions & 287 deletions

File tree

examples/deepscaler/run_deepscaler_disagg_v5p16.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ python -m tunix.cli.grpo_main \
6464
model_config.remat_config=3 \
6565
actor_model_config.mesh.shape="$trainer_mesh" \
6666
actor_model_config.mesh.axis_names="('fsdp','tp')" \
67-
reference_model_config.mesh=null \
68-
reference_model_config.same_mesh_as="actor" \
6967
rollout_model_config.mesh.shape="$rollout_mesh" \
7068
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
7169
\

examples/deepswe/run_deepswe_disagg_v5p_32.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ python -m tunix.cli.grpo_main \
8181
model_config.remat_config=3 \
8282
actor_model_config.mesh.shape="$trainer_mesh" \
8383
actor_model_config.mesh.axis_names="('fsdp','tp')" \
84-
reference_model_config.mesh=null \
85-
reference_model_config.same_mesh_as="actor" \
8684
rollout_model_config.mesh.shape="$rollout_mesh" \
8785
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
8886
\

examples/rl/grpo/gsm8k/run_qwen3_8b.sh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ num_generations="${num_generations:-4}"
4545
train_mesh="${train_mesh:-(8,1)}"
4646
rollout_mesh="${rollout_mesh:-(1,8)}"
4747

48-
checkpoint_dir="${checkpoint_dir:-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
48+
# Set rollout_colocate to the mesh name (e.g. "actor") to colocate the rollout
49+
# model on the same mesh as the actor model
50+
rollout_colocate="${rollout_colocate:-null}"
51+
52+
checkpoint_dir="${checkpoint_dir-gs://tunix/rl/checkpoints/gsm8k/qwen3/01}"
4953
checkpoint_suffix="${checkpoint_suffix:-$(printf '%04d' "$((RANDOM % 10000))")}"
5054
if [[ -n "$checkpoint_dir" && "$checkpoint_dir" != "null" ]]; then
5155
checkpoint_dir="${checkpoint_dir}_${checkpoint_suffix}"
@@ -79,8 +83,7 @@ python -m tunix.cli.grpo_main \
7983
model_config.remat_config=3 \
8084
actor_model_config.mesh.shape="$train_mesh" \
8185
actor_model_config.mesh.axis_names="('fsdp','tp')" \
82-
reference_model_config.mesh=null \
83-
reference_model_config.same_mesh_as="actor" \
86+
rollout_model_config.colocate_with="$rollout_colocate" \
8487
rollout_model_config.mesh.shape="$rollout_mesh" \
8588
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
8689
\

tests/cli/grpo_main_test.py

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,6 @@ def test_cli_empty_system_prompt_stays_empty_string(self):
644644
)
645645
self.assertEqual(p.config["agentic_grpo_config"]["system_prompt"], "")
646646

647-
648647
class SplitMeshConfigTest(absltest.TestCase):
649648

650649
def test_split_mesh_uses_explicit_role_meshes(self):
@@ -688,7 +687,6 @@ def test_split_mesh_uses_explicit_role_meshes(self):
688687
"shape": "(2,1)",
689688
"axis_names": "('fsdp','tp')",
690689
}
691-
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
692690
rollout_model_config = pipeline.config["rollout_model_config"]
693691
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
694692
rollout_model_config["mesh"] = {
@@ -732,6 +730,128 @@ def __init__(self, devices, axis_names, axis_types=None):
732730
role_to_mesh[rl_cluster_lib.Role.ACTOR],
733731
)
734732

733+
def test_colocate_with_reuses_device_slice_with_different_mesh(self):
734+
extra = """
735+
training_mode: "agentic_grpo"
736+
data_module: "tunix.cli.recipes.deepscaler_data"
737+
apply_chat_template_to_dataset: false
738+
data_config:
739+
train_data_path: "gs://fake/train.json"
740+
eval_data_path: "gs://fake/eval.parquet"
741+
prompt_key: "prompts"
742+
reward_functions: []
743+
verl_compatible: false
744+
chat_parser_config:
745+
type: "default"
746+
agent_class_path: null
747+
agent_kwargs: {}
748+
env_class_path: null
749+
env_kwargs: {}
750+
kubernetes_config: null
751+
agentic_grpo_config:
752+
num_generations: 2
753+
num_iterations: 1
754+
beta: 0.0
755+
epsilon: 0.2
756+
epsilon_high: 0.28
757+
system_prompt: ""
758+
max_concurrency: 1
759+
off_policy_steps: 0
760+
max_turns: 1
761+
context_ratio: 1
762+
sglang_jax_config:
763+
mem_fraction_static: 0.8
764+
vllm_config:
765+
hbm_utilization: 0.4
766+
"""
767+
pipeline = _make_pipeline(extra)
768+
actor_model_config = pipeline.config["actor_model_config"]
769+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
770+
actor_model_config["mesh"] = {
771+
"shape": "(2,1)",
772+
"axis_names": "('fsdp','tp')",
773+
}
774+
rollout_model_config = pipeline.config["rollout_model_config"]
775+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
776+
rollout_model_config["colocate_with"] = "actor"
777+
rollout_model_config["mesh"] = {
778+
"shape": "(1,2)",
779+
"axis_names": "('fsdp','tp')",
780+
}
781+
782+
fake_devices = list(range(4))
783+
784+
class FakeMesh:
785+
786+
def __init__(self, devices, axis_names, axis_types=None):
787+
self.devices = devices
788+
self.axis_names = axis_names
789+
self.axis_types = axis_types
790+
791+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
792+
with mock.patch.object(
793+
grpo_main.jax.sharding, "Mesh", side_effect=FakeMesh
794+
):
795+
role_to_mesh = pipeline.create_role_to_mesh()
796+
797+
self.assertSequenceEqual(
798+
role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(),
799+
[0, 1],
800+
)
801+
self.assertSequenceEqual(
802+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(),
803+
[0, 1],
804+
)
805+
self.assertEqual(
806+
role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.shape,
807+
(2, 1),
808+
)
809+
self.assertEqual(
810+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.shape,
811+
(1, 2),
812+
)
813+
814+
def test_empty_string_colocate_with_is_treated_as_unset(self):
815+
extra = """
816+
training_mode: "agentic_grpo"
817+
data_module: "tunix.cli.recipes.deepscaler_data"
818+
apply_chat_template_to_dataset: false
819+
data_config:
820+
train_data_path: "gs://fake/train.json"
821+
eval_data_path: "gs://fake/eval.parquet"
822+
prompt_key: "prompts"
823+
reward_functions: []
824+
verl_compatible: false
825+
chat_parser_config:
826+
type: "default"
827+
agent_class_path: null
828+
agent_kwargs: {}
829+
env_class_path: null
830+
env_kwargs: {}
831+
kubernetes_config: null
832+
agentic_grpo_config:
833+
num_generations: 2
834+
num_iterations: 1
835+
beta: 0.0
836+
epsilon: 0.2
837+
epsilon_high: 0.28
838+
system_prompt: ""
839+
max_concurrency: 1
840+
off_policy_steps: 0
841+
max_turns: 1
842+
context_ratio: 1
843+
sglang_jax_config:
844+
mem_fraction_static: 0.8
845+
vllm_config:
846+
hbm_utilization: 0.4
847+
"""
848+
pipeline = _make_pipeline(extra)
849+
rollout_model_config = pipeline.config["rollout_model_config"]
850+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
851+
rollout_model_config["colocate_with"] = ""
852+
853+
self.assertEmpty(pipeline._get_colocate_with_map())
854+
735855

736856
if __name__ == "__main__":
737857
absltest.main()

0 commit comments

Comments
 (0)