Skip to content

Commit 6e9c721

Browse files
committed
Fix mesh creation.
1 parent e67d0ef commit 6e9c721

11 files changed

Lines changed: 3550 additions & 130 deletions

File tree

.github/workflows/cpu-tests.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ jobs:
6363
run: |
6464
python -m pytest tests/cli/utils/ -v --tb=short
6565
66+
- name: Run shared mesh and topology tests
67+
run: |
68+
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
69+
6670
- name: Run perf tests
6771
run: |
6872
python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
- name: Run tunix tests not covered by the above categories
121121
run: |
122122
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
123-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
123+
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
124124
if [ "${code:-0}" = "5" ]; then
125125
echo "No tests collected (expected)."
126126
exit 0

tests/cli/config_test.py

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from tunix.sft import peft_trainer
2828
from tunix.tests import test_common as tc
2929
from tunix.utils import env_utils
30+
from tunix.utils import mesh as mesh_lib
3031

3132

3233
class ConfigTest(parameterized.TestCase):
@@ -262,7 +263,7 @@ def test_learning_rate_schedule_valid(self, overrides):
262263
self.assertIsNotNone(lr_schedule)
263264
self.assertTrue(callable(lr_schedule), "lr_schedule should be callable")
264265

265-
# --- Tests for create_mesh ---
266+
# --- Tests for mesh config parsing and mesh creation ---
266267
@parameterized.named_parameters(
267268
dict(
268269
testcase_name="valid_1d",
@@ -311,40 +312,93 @@ def test_create_mesh_valid(
311312
):
312313
mock_device_count_fn.return_value = mock_num_devices
313314
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
314-
mesh = hp.create_mesh("model_config")
315-
self.assertEqual(
316-
mesh,
317-
jax.make_mesh(
318-
expected[0],
319-
expected[1],
320-
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
321-
),
322-
)
315+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
316+
expected_mesh = object()
323317

324-
def test_create_mesh_with_assigned_devices(self):
325-
raw_keys = {
326-
"model_config": {
327-
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
328-
}
329-
}
330-
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
331-
assigned_devices = ["d0", "d1", "d2", "d3"]
318+
with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock:
319+
mesh = mesh_lib.create_mesh(axis_shapes, axis_names)
332320

333-
class FakeMesh:
321+
make_mesh_mock.assert_called_once_with(
322+
expected[0],
323+
expected[1],
324+
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
325+
)
326+
self.assertIs(mesh, expected_mesh)
327+
328+
def test_create_mesh_with_assigned_devices(self):
329+
raw_keys = {
330+
"model_config": {
331+
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
332+
}
333+
}
334+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
335+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
336+
assigned_devices = ["d0", "d1", "d2", "d3"]
334337

335-
def __init__(self, devices, axis_names, axis_types=None):
336-
self.devices = devices
337-
self.axis_names = axis_names
338-
self.axis_types = axis_types
338+
class FakeMesh:
339339

340-
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
341-
mesh = hp.create_mesh("model_config", devices=assigned_devices)
340+
def __init__(self, devices, axis_names, axis_types=None):
341+
self.devices = devices
342+
self.axis_names = axis_names
343+
self.axis_types = axis_types
342344

343-
self.assertEqual(mesh.devices.shape, (2, 2))
344-
self.assertSequenceEqual(
345-
mesh.devices.flatten().tolist(), assigned_devices
345+
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
346+
mesh = mesh_lib.create_mesh(
347+
axis_shapes,
348+
axis_names,
349+
devices=assigned_devices,
346350
)
347-
self.assertEqual(mesh.axis_names, ("x", "y"))
351+
352+
self.assertEqual(mesh.devices.shape, (2, 2))
353+
self.assertSequenceEqual(
354+
mesh.devices.flatten().tolist(), assigned_devices
355+
)
356+
self.assertEqual(mesh.axis_names, ("x", "y"))
357+
358+
def test_parse_mesh_allocation_policy_defaults_to_compact(self):
359+
raw_keys = {
360+
"model_config": {
361+
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
362+
}
363+
}
364+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
365+
366+
self.assertEqual(
367+
hp._parse_mesh_allocation_policy("model_config"),
368+
mesh_lib.normalize_allocation_policy(None),
369+
)
370+
371+
def test_parse_mesh_allocation_policy_validates_explicit_value(self):
372+
raw_keys = {
373+
"model_config": {
374+
"mesh": {
375+
"shape": "(2, 2)",
376+
"axis_names": "('x', 'y')",
377+
"allocation_policy": "performance",
378+
}
379+
}
380+
}
381+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
382+
383+
self.assertEqual(
384+
hp._parse_mesh_allocation_policy("model_config"),
385+
"PERFORMANCE",
386+
)
387+
388+
def test_parse_mesh_allocation_policy_rejects_invalid_value(self):
389+
raw_keys = {
390+
"model_config": {
391+
"mesh": {
392+
"shape": "(2, 2)",
393+
"axis_names": "('x', 'y')",
394+
"allocation_policy": "fastest",
395+
}
396+
}
397+
}
398+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
399+
400+
with self.assertRaisesRegex(ValueError, "allocation_policy must be one of"):
401+
hp._parse_mesh_allocation_policy("model_config")
348402

349403
@parameterized.named_parameters(
350404
dict(
@@ -424,11 +478,12 @@ def test_create_mesh_invalid(
424478
mock_num_devices,
425479
error_regex,
426480
):
427-
mock_device_count_fn.return_value = mock_num_devices
428-
with self.assertRaisesRegex(ValueError, error_regex):
429-
nested_dict = self.convert_nested_dict_to_list(raw_keys)
430-
hp = self.initialize_config(nested_dict)
431-
hp.create_mesh("model_config")
481+
mock_device_count_fn.return_value = mock_num_devices
482+
with self.assertRaisesRegex(ValueError, error_regex):
483+
nested_dict = self.convert_nested_dict_to_list(raw_keys)
484+
hp = self.initialize_config(nested_dict)
485+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
486+
mesh_lib.create_mesh(axis_shapes, axis_names)
432487

433488
@parameterized.named_parameters(
434489
dict(

tests/cli/grpo_main_test.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import os
2121
import pathlib
2222
import tempfile
23-
from typing import Any
24-
from typing import cast
2523
from unittest import mock
2624

2725
from absl.testing import absltest
@@ -559,7 +557,8 @@ def test_single_turn_kv_cache(self):
559557
def test_multi_turn_kv_cache(self):
560558
p = self._make_agentic_pipeline(max_turns=20, context_ratio=2)
561559
cfg = p.create_rollout_config()
562-
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)
560+
# max_prompt=256, max_response=512, 20 turns * ratio 2
561+
self.assertEqual(cfg.kv_cache_size, 256 + 512 * 2 * 20)
563562

564563
def test_standard_grpo_kv_cache(self):
565564
extra = """
@@ -579,26 +578,6 @@ def test_standard_grpo_kv_cache(self):
579578
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)
580579

581580

582-
class ComputeParamsTest(absltest.TestCase):
583-
584-
def test_compute_params_persists_dynamic_num_batches(self):
585-
pipeline = _make_pipeline("")
586-
pipeline.config["batch_size"] = 8
587-
pipeline.config["num_batches"] = 0
588-
pipeline.config["num_train_epochs"] = 1
589-
pipeline.config["train_fraction"] = 0.8
590-
rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"])
591-
rl_training_config["max_steps"] = 0
592-
593-
raw_dataset = mock.Mock()
594-
raw_dataset.__len__ = mock.Mock(return_value=7473)
595-
596-
pipeline.compute_params(raw_dataset)
597-
598-
self.assertEqual(pipeline.config["num_batches"], 934)
599-
self.assertEqual(rl_training_config["max_steps"], 747)
600-
601-
602581
# ---------------------------------------------------------------------------
603582
# GRPOConfig construction
604583
# ---------------------------------------------------------------------------
@@ -753,6 +732,73 @@ def __init__(self, devices, axis_names, axis_types=None):
753732
role_to_mesh[rl_cluster_lib.Role.ACTOR],
754733
)
755734

735+
def test_create_role_to_mesh_passes_configured_allocation_policy(self):
736+
extra = """
737+
training_mode: "agentic_grpo"
738+
verl_compatible: false
739+
chat_parser_config:
740+
type: "default"
741+
agent_class_path: null
742+
agent_kwargs: {}
743+
env_class_path: null
744+
env_kwargs: {}
745+
kubernetes_config: null
746+
agentic_grpo_config:
747+
num_generations: 2
748+
num_iterations: 1
749+
beta: 0.0
750+
epsilon: 0.2
751+
epsilon_high: 0.28
752+
system_prompt: ""
753+
max_concurrency: 1
754+
off_policy_steps: 0
755+
max_turns: 1
756+
context_ratio: 1
757+
sglang_jax_config:
758+
mem_fraction_static: 0.8
759+
vllm_config:
760+
hbm_utilization: 0.4
761+
"""
762+
pipeline = _make_pipeline(extra)
763+
actor_model_config = pipeline.config["actor_model_config"]
764+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
765+
actor_model_config["mesh"] = {
766+
"shape": "(2,1)",
767+
"axis_names": "('fsdp','tp')",
768+
"allocation_policy": "PERFORMANCE",
769+
}
770+
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
771+
rollout_model_config = pipeline.config["rollout_model_config"]
772+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
773+
rollout_model_config["mesh"] = {
774+
"shape": "(1,2)",
775+
"axis_names": "('fsdp','tp')",
776+
"allocation_policy": "PERFORMANCE",
777+
}
778+
779+
fake_devices = list(range(4))
780+
781+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
782+
with mock.patch.object(
783+
grpo_main.mesh_lib,
784+
"allocate_named_mesh_device_slices",
785+
return_value={
786+
"actor_model_config": [0, 1],
787+
"rollout_model_config": [2, 3],
788+
},
789+
) as allocate_mock, mock.patch.object(
790+
grpo_main.mesh_lib,
791+
"create_mesh",
792+
side_effect=[object(), object()],
793+
):
794+
pipeline.create_role_to_mesh()
795+
796+
allocate_mock.assert_called_once_with(
797+
[("actor_model_config", 2), ("rollout_model_config", 2)],
798+
devices=fake_devices,
799+
allocation_policy="PERFORMANCE",
800+
)
801+
756802

757803
if __name__ == "__main__":
758804
absltest.main()

0 commit comments

Comments
 (0)