Skip to content

Commit eea2c92

Browse files
committed
Fix mesh creation.
1 parent 35aeaf4 commit eea2c92

8 files changed

Lines changed: 1439 additions & 31 deletions

File tree

.github/workflows/cpu-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ on:
2222

2323
permissions:
2424
contents: read
25-
jobs:
25+
jobs:
2626
run:
2727
runs-on: ubuntu-latest
2828
steps:
@@ -63,6 +63,10 @@ jobs:
6363
run: |
6464
python -m pytest tests/cli/utils/ -v --tb=short
6565
66+
- name: Run shared mesh and topology tests
67+
run: |
68+
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
69+
6670
- name: Run perf tests
6771
run: |
6872
python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
- name: Run tunix tests not covered by the above categories
121121
run: |
122122
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
123-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
123+
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
124124
if [ "${code:-0}" = "5" ]; then
125125
echo "No tests collected (expected)."
126126
exit 0

tests/cli/grpo_main_test.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,97 @@ def __init__(self, devices, axis_names, axis_types=None):
732732
role_to_mesh[rl_cluster_lib.Role.ACTOR],
733733
)
734734

735+
def test_split_mesh_delegates_device_allocation_to_mesh_utils(self):
736+
extra = """
737+
training_mode: "agentic_grpo"
738+
data_module: "tunix.cli.recipes.deepscaler_data"
739+
apply_chat_template_to_dataset: false
740+
data_config:
741+
train_data_path: "gs://fake/train.json"
742+
eval_data_path: "gs://fake/eval.parquet"
743+
prompt_key: "prompts"
744+
reward_functions: []
745+
verl_compatible: false
746+
chat_parser_config:
747+
type: "default"
748+
agent_class_path: null
749+
agent_kwargs: {}
750+
env_class_path: null
751+
env_kwargs: {}
752+
kubernetes_config: null
753+
agentic_grpo_config:
754+
num_generations: 2
755+
num_iterations: 1
756+
beta: 0.0
757+
epsilon: 0.2
758+
epsilon_high: 0.28
759+
system_prompt: ""
760+
max_concurrency: 1
761+
off_policy_steps: 0
762+
max_turns: 1
763+
context_ratio: 1
764+
sglang_jax_config:
765+
mem_fraction_static: 0.8
766+
vllm_config:
767+
hbm_utilization: 0.4
768+
"""
769+
pipeline = _make_pipeline(extra)
770+
actor_model_config = pipeline.config["actor_model_config"]
771+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
772+
actor_model_config["mesh"] = {
773+
"shape": "(1,2)",
774+
"axis_names": "('fsdp','tp')",
775+
}
776+
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
777+
rollout_model_config = pipeline.config["rollout_model_config"]
778+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
779+
rollout_model_config["mesh"] = {
780+
"shape": "(1,2)",
781+
"axis_names": "('fsdp','tp')",
782+
}
783+
784+
fake_devices = ["a0", "a1", "r0", "r1"]
785+
allocated_devices = {
786+
"actor_model_config": ["a0", "a1"],
787+
"rollout_model_config": ["r0", "r1"],
788+
}
789+
created_mesh_devices = {}
790+
791+
def fake_create_mesh(model_key, devices=None):
792+
created_mesh_devices[model_key] = list(devices)
793+
return (model_key, tuple(devices))
794+
795+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
796+
with mock.patch.object(
797+
grpo_main.mesh_lib,
798+
"allocate_named_mesh_device_slices",
799+
return_value=allocated_devices,
800+
) as allocate_mock:
801+
with mock.patch.object(pipeline, "create_mesh", side_effect=fake_create_mesh):
802+
role_to_mesh = pipeline.create_role_to_mesh()
803+
804+
allocate_mock.assert_called_once_with(
805+
[
806+
("actor_model_config", 2),
807+
("rollout_model_config", 2),
808+
],
809+
devices=fake_devices,
810+
)
811+
self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"])
812+
self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"])
813+
self.assertEqual(
814+
role_to_mesh[rl_cluster_lib.Role.ACTOR],
815+
("actor_model_config", ("a0", "a1")),
816+
)
817+
self.assertIs(
818+
role_to_mesh[rl_cluster_lib.Role.REFERENCE],
819+
role_to_mesh[rl_cluster_lib.Role.ACTOR],
820+
)
821+
self.assertEqual(
822+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT],
823+
("rollout_model_config", ("r0", "r1")),
824+
)
825+
735826

736827
if __name__ == "__main__":
737828
absltest.main()

0 commit comments

Comments
 (0)