Skip to content

Commit f0d3dc8

Browse files
committed
Fix mesh creation.
1 parent 3ce9765 commit f0d3dc8

8 files changed

Lines changed: 1925 additions & 30 deletions

File tree

.github/workflows/cpu-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ on:
2222

2323
permissions:
2424
contents: read
25-
jobs:
25+
jobs:
2626
run:
2727
runs-on: ubuntu-latest
2828
steps:
@@ -63,6 +63,10 @@ jobs:
6363
run: |
6464
python -m pytest tests/cli/utils/ -v --tb=short
6565
66+
- name: Run shared mesh and topology tests
67+
run: |
68+
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
69+
6670
- name: Run perf tests
6771
run: |
6872
python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
- name: Run tunix tests not covered by the above categories
121121
run: |
122122
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
123-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
123+
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
124124
if [ "${code:-0}" = "5" ]; then
125125
echo "No tests collected (expected)."
126126
exit 0

tests/cli/grpo_main_test.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,97 @@ def __init__(self, devices, axis_names, axis_types=None):
753753
role_to_mesh[rl_cluster_lib.Role.ACTOR],
754754
)
755755

756+
def test_split_mesh_delegates_device_allocation_to_mesh_utils(self):
757+
extra = """
758+
training_mode: "agentic_grpo"
759+
data_module: "tunix.cli.recipes.deepscaler_data"
760+
apply_chat_template_to_dataset: false
761+
data_config:
762+
train_data_path: "gs://fake/train.json"
763+
eval_data_path: "gs://fake/eval.parquet"
764+
prompt_key: "prompts"
765+
reward_functions: []
766+
verl_compatible: false
767+
chat_parser_config:
768+
type: "default"
769+
agent_class_path: null
770+
agent_kwargs: {}
771+
env_class_path: null
772+
env_kwargs: {}
773+
kubernetes_config: null
774+
agentic_grpo_config:
775+
num_generations: 2
776+
num_iterations: 1
777+
beta: 0.0
778+
epsilon: 0.2
779+
epsilon_high: 0.28
780+
system_prompt: ""
781+
max_concurrency: 1
782+
off_policy_steps: 0
783+
max_turns: 1
784+
context_ratio: 1
785+
sglang_jax_config:
786+
mem_fraction_static: 0.8
787+
vllm_config:
788+
hbm_utilization: 0.4
789+
"""
790+
pipeline = _make_pipeline(extra)
791+
actor_model_config = pipeline.config["actor_model_config"]
792+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
793+
actor_model_config["mesh"] = {
794+
"shape": "(1,2)",
795+
"axis_names": "('fsdp','tp')",
796+
}
797+
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
798+
rollout_model_config = pipeline.config["rollout_model_config"]
799+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
800+
rollout_model_config["mesh"] = {
801+
"shape": "(1,2)",
802+
"axis_names": "('fsdp','tp')",
803+
}
804+
805+
fake_devices = ["a0", "a1", "r0", "r1"]
806+
allocated_devices = {
807+
"actor_model_config": ["a0", "a1"],
808+
"rollout_model_config": ["r0", "r1"],
809+
}
810+
created_mesh_devices = {}
811+
812+
def fake_create_mesh(model_key, devices=None):
813+
created_mesh_devices[model_key] = list(devices)
814+
return (model_key, tuple(devices))
815+
816+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
817+
with mock.patch.object(
818+
grpo_main.mesh_lib,
819+
"allocate_named_mesh_device_slices",
820+
return_value=allocated_devices,
821+
) as allocate_mock:
822+
with mock.patch.object(pipeline, "create_mesh", side_effect=fake_create_mesh):
823+
role_to_mesh = pipeline.create_role_to_mesh()
824+
825+
allocate_mock.assert_called_once_with(
826+
[
827+
("actor_model_config", 2),
828+
("rollout_model_config", 2),
829+
],
830+
devices=fake_devices,
831+
)
832+
self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"])
833+
self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"])
834+
self.assertEqual(
835+
role_to_mesh[rl_cluster_lib.Role.ACTOR],
836+
("actor_model_config", ("a0", "a1")),
837+
)
838+
self.assertIs(
839+
role_to_mesh[rl_cluster_lib.Role.REFERENCE],
840+
role_to_mesh[rl_cluster_lib.Role.ACTOR],
841+
)
842+
self.assertEqual(
843+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT],
844+
("rollout_model_config", ("r0", "r1")),
845+
)
846+
756847

757848
if __name__ == "__main__":
758849
absltest.main()

0 commit comments

Comments
 (0)