Skip to content

Commit 130eb5f

Browse files
committed
Fix mesh creation.
1 parent fa7cabb commit 130eb5f

11 files changed

Lines changed: 3574 additions & 132 deletions

File tree

.github/workflows/cpu-tests.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ jobs:
6363
run: |
6464
python -m pytest tests/cli/utils/ -v --tb=short
6565
66+
- name: Run shared mesh and topology tests
67+
run: |
68+
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
69+
6670
- name: Run perf tests
6771
run: |
6872
python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
- name: Run tunix tests not covered by the above categories
121121
run: |
122122
# This category catches tests that have been added but are not yet covered by any CI category above. Whenever you add a new folder under tests/, add a new category for it above and add a matching --ignore entry here.
123-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
123+
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
124124
if [ "${code:-0}" = "5" ]; then
125125
echo "No tests collected (expected)."
126126
exit 0

tests/cli/config_test.py

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tunix.sft import peft_trainer
2828
from tunix.tests import test_common as tc
2929
from tunix.utils import env_utils
30+
from tunix.utils import mesh as mesh_lib
3031

3132

3233
class ConfigTest(parameterized.TestCase):
@@ -262,7 +263,7 @@ def test_learning_rate_schedule_valid(self, overrides):
262263
self.assertIsNotNone(lr_schedule)
263264
self.assertTrue(callable(lr_schedule), "lr_schedule should be callable")
264265

265-
# --- Tests for create_mesh ---
266+
# --- Tests for mesh config parsing and mesh creation ---
266267
@parameterized.named_parameters(
267268
dict(
268269
testcase_name="valid_1d",
@@ -311,40 +312,93 @@ def test_create_mesh_valid(
311312
):
312313
mock_device_count_fn.return_value = mock_num_devices
313314
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
314-
mesh = hp.create_mesh("model_config")
315-
self.assertEqual(
316-
mesh,
317-
jax.make_mesh(
318-
expected[0],
319-
expected[1],
320-
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
321-
),
322-
)
315+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
316+
expected_mesh = object()
323317

324-
def test_create_mesh_with_assigned_devices(self):
325-
raw_keys = {
326-
"model_config": {
327-
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
328-
}
329-
}
330-
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
331-
assigned_devices = ["d0", "d1", "d2", "d3"]
318+
with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock:
319+
mesh = mesh_lib.create_mesh(axis_shapes, axis_names)
332320

333-
class FakeMesh:
321+
make_mesh_mock.assert_called_once_with(
322+
expected[0],
323+
expected[1],
324+
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
325+
)
326+
self.assertIs(mesh, expected_mesh)
327+
328+
def test_create_mesh_with_assigned_devices(self):
329+
raw_keys = {
330+
"model_config": {
331+
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
332+
}
333+
}
334+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
335+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
336+
assigned_devices = ["d0", "d1", "d2", "d3"]
334337

335-
def __init__(self, devices, axis_names, axis_types=None):
336-
self.devices = devices
337-
self.axis_names = axis_names
338-
self.axis_types = axis_types
338+
class FakeMesh:
339339

340-
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
341-
mesh = hp.create_mesh("model_config", devices=assigned_devices)
340+
def __init__(self, devices, axis_names, axis_types=None):
341+
self.devices = devices
342+
self.axis_names = axis_names
343+
self.axis_types = axis_types
342344

343-
self.assertEqual(mesh.devices.shape, (2, 2))
344-
self.assertSequenceEqual(
345-
mesh.devices.flatten().tolist(), assigned_devices
345+
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
346+
mesh = mesh_lib.create_mesh(
347+
axis_shapes,
348+
axis_names,
349+
devices=assigned_devices,
346350
)
347-
self.assertEqual(mesh.axis_names, ("x", "y"))
351+
352+
self.assertEqual(mesh.devices.shape, (2, 2))
353+
self.assertSequenceEqual(
354+
mesh.devices.flatten().tolist(), assigned_devices
355+
)
356+
self.assertEqual(mesh.axis_names, ("x", "y"))
357+
358+
def test_parse_mesh_allocation_policy_defaults_to_compact(self):
359+
raw_keys = {
360+
"model_config": {
361+
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
362+
}
363+
}
364+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
365+
366+
self.assertEqual(
367+
hp._parse_mesh_allocation_policy("model_config"),
368+
mesh_lib.normalize_allocation_policy(None),
369+
)
370+
371+
def test_parse_mesh_allocation_policy_validates_explicit_value(self):
372+
raw_keys = {
373+
"model_config": {
374+
"mesh": {
375+
"shape": "(2, 2)",
376+
"axis_names": "('x', 'y')",
377+
"allocation_policy": "performance",
378+
}
379+
}
380+
}
381+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
382+
383+
self.assertEqual(
384+
hp._parse_mesh_allocation_policy("model_config"),
385+
"PERFORMANCE",
386+
)
387+
388+
def test_parse_mesh_allocation_policy_rejects_invalid_value(self):
389+
raw_keys = {
390+
"model_config": {
391+
"mesh": {
392+
"shape": "(2, 2)",
393+
"axis_names": "('x', 'y')",
394+
"allocation_policy": "fastest",
395+
}
396+
}
397+
}
398+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
399+
400+
with self.assertRaisesRegex(ValueError, "allocation_policy must be one of"):
401+
hp._parse_mesh_allocation_policy("model_config")
348402

349403
@parameterized.named_parameters(
350404
dict(
@@ -424,11 +478,12 @@ def test_create_mesh_invalid(
424478
mock_num_devices,
425479
error_regex,
426480
):
427-
mock_device_count_fn.return_value = mock_num_devices
428-
with self.assertRaisesRegex(ValueError, error_regex):
429-
nested_dict = self.convert_nested_dict_to_list(raw_keys)
430-
hp = self.initialize_config(nested_dict)
431-
hp.create_mesh("model_config")
481+
mock_device_count_fn.return_value = mock_num_devices
482+
with self.assertRaisesRegex(ValueError, error_regex):
483+
nested_dict = self.convert_nested_dict_to_list(raw_keys)
484+
hp = self.initialize_config(nested_dict)
485+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
486+
mesh_lib.create_mesh(axis_shapes, axis_names)
432487

433488
@parameterized.named_parameters(
434489
dict(

tests/cli/grpo_main_test.py

Lines changed: 93 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import os
2121
import pathlib
2222
import tempfile
23-
from typing import Any
24-
from typing import cast
2523
from unittest import mock
2624

2725
from absl.testing import absltest
@@ -573,26 +571,6 @@ def test_standard_grpo_kv_cache(self):
573571
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)
574572

575573

576-
class ComputeParamsTest(absltest.TestCase):
577-
578-
def test_compute_params_persists_dynamic_num_batches(self):
579-
pipeline = _make_pipeline("")
580-
pipeline.config["batch_size"] = 8
581-
pipeline.config["num_batches"] = 0
582-
pipeline.config["num_train_epochs"] = 1
583-
pipeline.config["train_fraction"] = 0.8
584-
rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"])
585-
rl_training_config["max_steps"] = 0
586-
587-
raw_dataset = mock.Mock()
588-
raw_dataset.__len__ = mock.Mock(return_value=7473)
589-
590-
pipeline.compute_params(raw_dataset)
591-
592-
self.assertEqual(pipeline.config["num_batches"], 934)
593-
self.assertEqual(rl_training_config["max_steps"], 747)
594-
595-
596574
# ---------------------------------------------------------------------------
597575
# GRPOConfig construction
598576
# ---------------------------------------------------------------------------
@@ -710,7 +688,21 @@ def test_split_mesh_uses_explicit_role_meshes(self):
710688
"axis_names": "('fsdp','tp')",
711689
}
712690

713-
fake_devices = list(range(4))
691+
class FakeDevice:
692+
693+
def __init__(self, device_id, coords):
694+
self.id = device_id
695+
self.coords = coords
696+
self.process_index = 0
697+
self.slice_index = 0
698+
self.device_kind = "TPU v5e"
699+
700+
fake_devices = [
701+
FakeDevice(0, (0, 0)),
702+
FakeDevice(1, (1, 0)),
703+
FakeDevice(2, (0, 1)),
704+
FakeDevice(3, (1, 1)),
705+
]
714706

715707
class FakeMesh:
716708

@@ -726,11 +718,21 @@ def __init__(self, devices, axis_names, axis_types=None):
726718
role_to_mesh = pipeline.create_role_to_mesh()
727719

728720
self.assertSequenceEqual(
729-
role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(),
721+
[
722+
device.id
723+
for device in role_to_mesh[rl_cluster_lib.Role.ACTOR]
724+
.devices.flatten()
725+
.tolist()
726+
],
730727
[0, 1],
731728
)
732729
self.assertSequenceEqual(
733-
role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(),
730+
[
731+
device.id
732+
for device in role_to_mesh[rl_cluster_lib.Role.ROLLOUT]
733+
.devices.flatten()
734+
.tolist()
735+
],
734736
[2, 3],
735737
)
736738
self.assertEqual(
@@ -746,6 +748,72 @@ def __init__(self, devices, axis_names, axis_types=None):
746748
role_to_mesh[rl_cluster_lib.Role.ACTOR],
747749
)
748750

751+
def test_create_role_to_mesh_passes_configured_allocation_policy(self):
752+
extra = """
753+
training_mode: "agentic_grpo"
754+
verl_compatible: false
755+
chat_parser_config:
756+
type: "default"
757+
agent_class_path: null
758+
agent_kwargs: {}
759+
env_class_path: null
760+
env_kwargs: {}
761+
kubernetes_config: null
762+
agentic_grpo_config:
763+
num_generations: 2
764+
num_iterations: 1
765+
beta: 0.0
766+
epsilon: 0.2
767+
epsilon_high: 0.28
768+
system_prompt: ""
769+
max_concurrency: 1
770+
off_policy_steps: 0
771+
max_turns: 1
772+
sglang_jax_config:
773+
mem_fraction_static: 0.8
774+
vllm_config:
775+
hbm_utilization: 0.4
776+
"""
777+
pipeline = _make_pipeline(extra)
778+
actor_model_config = pipeline.config["actor_model_config"]
779+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
780+
actor_model_config["mesh"] = {
781+
"shape": "(2,1)",
782+
"axis_names": "('fsdp','tp')",
783+
"allocation_policy": "PERFORMANCE",
784+
}
785+
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
786+
rollout_model_config = pipeline.config["rollout_model_config"]
787+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
788+
rollout_model_config["mesh"] = {
789+
"shape": "(1,2)",
790+
"axis_names": "('fsdp','tp')",
791+
"allocation_policy": "PERFORMANCE",
792+
}
793+
794+
fake_devices = list(range(4))
795+
796+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
797+
with mock.patch.object(
798+
grpo_main.mesh_lib,
799+
"allocate_named_mesh_device_slices",
800+
return_value={
801+
"actor_model_config": [0, 1],
802+
"rollout_model_config": [2, 3],
803+
},
804+
) as allocate_mock, mock.patch.object(
805+
grpo_main.mesh_lib,
806+
"create_mesh",
807+
side_effect=[object(), object()],
808+
):
809+
pipeline.create_role_to_mesh()
810+
811+
allocate_mock.assert_called_once_with(
812+
[("actor_model_config", 2), ("rollout_model_config", 2)],
813+
devices=fake_devices,
814+
allocation_policy="PERFORMANCE",
815+
)
816+
749817

750818
if __name__ == "__main__":
751819
absltest.main()

0 commit comments

Comments
 (0)