Skip to content

Commit cc0de37

Browse files
committed
Fix mesh creation.
1 parent 43f9eaa commit cc0de37

11 files changed

Lines changed: 3792 additions & 130 deletions

File tree

.github/workflows/cpu-tests.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ jobs:
6363
run: |
6464
python -m pytest tests/cli/utils/ -v --tb=short
6565
66+
- name: Run shared mesh and topology tests
67+
run: |
68+
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
69+
6670
- name: Run perf tests
6771
run: |
6872
python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
- name: Run tunix tests not covered by the above categories
121121
run: |
122122
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
123-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
123+
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
124124
if [ "${code:-0}" = "5" ]; then
125125
echo "No tests collected (expected)."
126126
exit 0

tests/cli/config_test.py

Lines changed: 91 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
from tunix.sft import peft_trainer
2828
from tunix.tests import test_common as tc
2929
from tunix.utils import env_utils
30+
from tunix.utils import mesh as mesh_lib
31+
32+
os.environ.setdefault("HF_TOKEN", "TestToken")
3033

3134

3235
class ConfigTest(parameterized.TestCase):
@@ -262,7 +265,7 @@ def test_learning_rate_schedule_valid(self, overrides):
262265
self.assertIsNotNone(lr_schedule)
263266
self.assertTrue(callable(lr_schedule), "lr_schedule should be callable")
264267

265-
# --- Tests for create_mesh ---
268+
# --- Tests for mesh config parsing and mesh creation ---
266269
@parameterized.named_parameters(
267270
dict(
268271
testcase_name="valid_1d",
@@ -311,40 +314,93 @@ def test_create_mesh_valid(
311314
):
312315
mock_device_count_fn.return_value = mock_num_devices
313316
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
314-
mesh = hp.create_mesh("model_config")
315-
self.assertEqual(
316-
mesh,
317-
jax.make_mesh(
318-
expected[0],
319-
expected[1],
320-
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
321-
),
322-
)
317+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
318+
expected_mesh = object()
323319

324-
def test_create_mesh_with_assigned_devices(self):
325-
raw_keys = {
326-
"model_config": {
327-
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
328-
}
329-
}
330-
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
331-
assigned_devices = ["d0", "d1", "d2", "d3"]
320+
with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock:
321+
mesh = mesh_lib.create_mesh(axis_shapes, axis_names)
332322

333-
class FakeMesh:
323+
make_mesh_mock.assert_called_once_with(
324+
expected[0],
325+
expected[1],
326+
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
327+
)
328+
self.assertIs(mesh, expected_mesh)
329+
330+
def test_create_mesh_with_assigned_devices(self):
331+
raw_keys = {
332+
"model_config": {
333+
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
334+
}
335+
}
336+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
337+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
338+
assigned_devices = ["d0", "d1", "d2", "d3"]
334339

335-
def __init__(self, devices, axis_names, axis_types=None):
336-
self.devices = devices
337-
self.axis_names = axis_names
338-
self.axis_types = axis_types
340+
class FakeMesh:
339341

340-
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
341-
mesh = hp.create_mesh("model_config", devices=assigned_devices)
342+
def __init__(self, devices, axis_names, axis_types=None):
343+
self.devices = devices
344+
self.axis_names = axis_names
345+
self.axis_types = axis_types
342346

343-
self.assertEqual(mesh.devices.shape, (2, 2))
344-
self.assertSequenceEqual(
345-
mesh.devices.flatten().tolist(), assigned_devices
347+
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
348+
mesh = mesh_lib.create_mesh(
349+
axis_shapes,
350+
axis_names,
351+
devices=assigned_devices,
346352
)
347-
self.assertEqual(mesh.axis_names, ("x", "y"))
353+
354+
self.assertEqual(mesh.devices.shape, (2, 2))
355+
self.assertSequenceEqual(
356+
mesh.devices.flatten().tolist(), assigned_devices
357+
)
358+
self.assertEqual(mesh.axis_names, ("x", "y"))
359+
360+
def test_parse_mesh_allocation_policy_defaults_to_compact(self):
  """Checks the policy parsed from a mesh config with no `allocation_policy`.

  The mesh section only sets `shape` and `axis_names`. The parsed policy must
  equal whatever `mesh_lib.normalize_allocation_policy(None)` returns
  (presumably the compact policy, going by the test name -- confirm against
  `mesh_lib`).
  """
  mesh_section = {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
  config_args = self.convert_nested_dict_to_list(
      {"model_config": {"mesh": mesh_section}}
  )
  hp = self.initialize_config(config_args)

  parsed_policy = hp._parse_mesh_allocation_policy("model_config")
  default_policy = mesh_lib.normalize_allocation_policy(None)

  self.assertEqual(parsed_policy, default_policy)
372+
373+
def test_parse_mesh_allocation_policy_validates_explicit_value(self):
  """Checks that an explicit `allocation_policy` is accepted and normalized.

  The config supplies the lowercase value "performance"; the parser is
  expected to hand back "PERFORMANCE".
  """
  mesh_section = {
      "shape": "(2, 2)",
      "axis_names": "('x', 'y')",
      "allocation_policy": "performance",
  }
  config_args = self.convert_nested_dict_to_list(
      {"model_config": {"mesh": mesh_section}}
  )
  hp = self.initialize_config(config_args)

  parsed_policy = hp._parse_mesh_allocation_policy("model_config")

  self.assertEqual(parsed_policy, "PERFORMANCE")
389+
390+
def test_parse_mesh_allocation_policy_rejects_invalid_value(self):
  """Checks that an unrecognized `allocation_policy` raises `ValueError`.

  The config is built outside the raising assertion on purpose: the error
  has to come from parsing the policy, not from config initialization.
  """
  mesh_section = {
      "shape": "(2, 2)",
      "axis_names": "('x', 'y')",
      "allocation_policy": "fastest",
  }
  config_args = self.convert_nested_dict_to_list(
      {"model_config": {"mesh": mesh_section}}
  )
  hp = self.initialize_config(config_args)

  self.assertRaisesRegex(
      ValueError,
      "allocation_policy must be one of",
      hp._parse_mesh_allocation_policy,
      "model_config",
  )
348404

349405
@parameterized.named_parameters(
350406
dict(
@@ -424,11 +480,12 @@ def test_create_mesh_invalid(
424480
mock_num_devices,
425481
error_regex,
426482
):
427-
mock_device_count_fn.return_value = mock_num_devices
428-
with self.assertRaisesRegex(ValueError, error_regex):
429-
nested_dict = self.convert_nested_dict_to_list(raw_keys)
430-
hp = self.initialize_config(nested_dict)
431-
hp.create_mesh("model_config")
483+
mock_device_count_fn.return_value = mock_num_devices
484+
with self.assertRaisesRegex(ValueError, error_regex):
485+
nested_dict = self.convert_nested_dict_to_list(raw_keys)
486+
hp = self.initialize_config(nested_dict)
487+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
488+
mesh_lib.create_mesh(axis_shapes, axis_names)
432489

433490
@parameterized.named_parameters(
434491
dict(

tests/cli/grpo_main_test.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import os
2121
import pathlib
2222
import tempfile
23-
from typing import Any
24-
from typing import cast
2523
from unittest import mock
2624

2725
from absl.testing import absltest
@@ -559,7 +557,8 @@ def test_single_turn_kv_cache(self):
559557
def test_multi_turn_kv_cache(self):
560558
p = self._make_agentic_pipeline(max_turns=20, context_ratio=2)
561559
cfg = p.create_rollout_config()
562-
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)
560+
# max_prompt=256, max_response=512, 20 turns * ratio 2
561+
self.assertEqual(cfg.kv_cache_size, 256 + 512 * 2 * 20)
563562

564563
def test_standard_grpo_kv_cache(self):
565564
extra = """
@@ -579,26 +578,6 @@ def test_standard_grpo_kv_cache(self):
579578
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)
580579

581580

582-
class ComputeParamsTest(absltest.TestCase):
583-
584-
def test_compute_params_persists_dynamic_num_batches(self):
585-
pipeline = _make_pipeline("")
586-
pipeline.config["batch_size"] = 8
587-
pipeline.config["num_batches"] = 0
588-
pipeline.config["num_train_epochs"] = 1
589-
pipeline.config["train_fraction"] = 0.8
590-
rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"])
591-
rl_training_config["max_steps"] = 0
592-
593-
raw_dataset = mock.Mock()
594-
raw_dataset.__len__ = mock.Mock(return_value=7473)
595-
596-
pipeline.compute_params(raw_dataset)
597-
598-
self.assertEqual(pipeline.config["num_batches"], 934)
599-
self.assertEqual(rl_training_config["max_steps"], 747)
600-
601-
602581
# ---------------------------------------------------------------------------
603582
# GRPOConfig construction
604583
# ---------------------------------------------------------------------------
@@ -753,6 +732,73 @@ def __init__(self, devices, axis_names, axis_types=None):
753732
role_to_mesh[rl_cluster_lib.Role.ACTOR],
754733
)
755734

735+
def test_create_role_to_mesh_passes_configured_allocation_policy(self):
736+
extra = """
737+
training_mode: "agentic_grpo"
738+
verl_compatible: false
739+
chat_parser_config:
740+
type: "default"
741+
agent_class_path: null
742+
agent_kwargs: {}
743+
env_class_path: null
744+
env_kwargs: {}
745+
kubernetes_config: null
746+
agentic_grpo_config:
747+
num_generations: 2
748+
num_iterations: 1
749+
beta: 0.0
750+
epsilon: 0.2
751+
epsilon_high: 0.28
752+
system_prompt: ""
753+
max_concurrency: 1
754+
off_policy_steps: 0
755+
max_turns: 1
756+
context_ratio: 1
757+
sglang_jax_config:
758+
mem_fraction_static: 0.8
759+
vllm_config:
760+
hbm_utilization: 0.4
761+
"""
762+
pipeline = _make_pipeline(extra)
763+
actor_model_config = pipeline.config["actor_model_config"]
764+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
765+
actor_model_config["mesh"] = {
766+
"shape": "(2,1)",
767+
"axis_names": "('fsdp','tp')",
768+
"allocation_policy": "PERFORMANCE",
769+
}
770+
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
771+
rollout_model_config = pipeline.config["rollout_model_config"]
772+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
773+
rollout_model_config["mesh"] = {
774+
"shape": "(1,2)",
775+
"axis_names": "('fsdp','tp')",
776+
"allocation_policy": "PERFORMANCE",
777+
}
778+
779+
fake_devices = list(range(4))
780+
781+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
782+
with mock.patch.object(
783+
grpo_main.mesh_lib,
784+
"allocate_named_mesh_device_slices",
785+
return_value={
786+
"actor_model_config": [0, 1],
787+
"rollout_model_config": [2, 3],
788+
},
789+
) as allocate_mock, mock.patch.object(
790+
grpo_main.mesh_lib,
791+
"create_mesh",
792+
side_effect=[object(), object()],
793+
):
794+
pipeline.create_role_to_mesh()
795+
796+
allocate_mock.assert_called_once_with(
797+
[("actor_model_config", 2), ("rollout_model_config", 2)],
798+
devices=fake_devices,
799+
allocation_policy="PERFORMANCE",
800+
)
801+
756802

757803
if __name__ == "__main__":
758804
absltest.main()

0 commit comments

Comments
 (0)