Skip to content

Commit f6dabb1

Browse files
committed
Fix mesh creation.
1 parent 3ce9765 commit f6dabb1

11 files changed

Lines changed: 2929 additions & 121 deletions

File tree

.github/workflows/cpu-tests.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ on:
2222

2323
permissions:
2424
contents: read
25-
jobs:
25+
jobs:
2626
run:
2727
runs-on: ubuntu-latest
2828
steps:
@@ -63,6 +63,10 @@ jobs:
6363
run: |
6464
python -m pytest tests/cli/utils/ -v --tb=short
6565
66+
- name: Run shared mesh and topology tests
67+
run: |
68+
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short
69+
6670
- name: Run perf tests
6771
run: |
6872
python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
- name: Run tunix tests not covered by the above categories
121121
run: |
122122
# This category catches tests that exist under tests/ but are not yet covered by a CI category above. Whenever you add a new folder or test file under tests/ that gets its own category, add that category above and --ignore the same path here.
123-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
123+
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
124124
if [ "${code:-0}" = "5" ]; then
125125
echo "No tests collected (expected)."
126126
exit 0

tests/cli/config_test.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
from tunix.sft import peft_trainer
2828
from tunix.tests import test_common as tc
2929
from tunix.utils import env_utils
30+
from tunix.utils import mesh as mesh_lib
31+
32+
os.environ.setdefault("HF_TOKEN", "TestToken")
3033

3134

3235
class ConfigTest(parameterized.TestCase):
@@ -262,7 +265,7 @@ def test_learning_rate_schedule_valid(self, overrides):
262265
self.assertIsNotNone(lr_schedule)
263266
self.assertTrue(callable(lr_schedule), "lr_schedule should be callable")
264267

265-
# --- Tests for create_mesh ---
268+
# --- Tests for mesh config parsing and mesh creation ---
266269
@parameterized.named_parameters(
267270
dict(
268271
testcase_name="valid_1d",
@@ -311,40 +314,48 @@ def test_create_mesh_valid(
311314
):
312315
mock_device_count_fn.return_value = mock_num_devices
313316
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
314-
mesh = hp.create_mesh("model_config")
315-
self.assertEqual(
316-
mesh,
317-
jax.make_mesh(
318-
expected[0],
319-
expected[1],
320-
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
321-
),
322-
)
317+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
318+
expected_mesh = object()
323319

324-
def test_create_mesh_with_assigned_devices(self):
325-
raw_keys = {
326-
"model_config": {
327-
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
328-
}
329-
}
330-
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
331-
assigned_devices = ["d0", "d1", "d2", "d3"]
320+
with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock:
321+
mesh = mesh_lib.create_mesh(axis_shapes, axis_names)
332322

333-
class FakeMesh:
323+
make_mesh_mock.assert_called_once_with(
324+
expected[0],
325+
expected[1],
326+
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
327+
)
328+
self.assertIs(mesh, expected_mesh)
329+
330+
def test_create_mesh_with_assigned_devices(self):
331+
raw_keys = {
332+
"model_config": {
333+
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
334+
}
335+
}
336+
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
337+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
338+
assigned_devices = ["d0", "d1", "d2", "d3"]
334339

335-
def __init__(self, devices, axis_names, axis_types=None):
336-
self.devices = devices
337-
self.axis_names = axis_names
338-
self.axis_types = axis_types
340+
class FakeMesh:
339341

340-
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
341-
mesh = hp.create_mesh("model_config", devices=assigned_devices)
342+
def __init__(self, devices, axis_names, axis_types=None):
343+
self.devices = devices
344+
self.axis_names = axis_names
345+
self.axis_types = axis_types
342346

343-
self.assertEqual(mesh.devices.shape, (2, 2))
344-
self.assertSequenceEqual(
345-
mesh.devices.flatten().tolist(), assigned_devices
347+
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
348+
mesh = mesh_lib.create_mesh(
349+
axis_shapes,
350+
axis_names,
351+
devices=assigned_devices,
346352
)
347-
self.assertEqual(mesh.axis_names, ("x", "y"))
353+
354+
self.assertEqual(mesh.devices.shape, (2, 2))
355+
self.assertSequenceEqual(
356+
mesh.devices.flatten().tolist(), assigned_devices
357+
)
358+
self.assertEqual(mesh.axis_names, ("x", "y"))
348359

349360
@parameterized.named_parameters(
350361
dict(
@@ -424,11 +435,12 @@ def test_create_mesh_invalid(
424435
mock_num_devices,
425436
error_regex,
426437
):
427-
mock_device_count_fn.return_value = mock_num_devices
428-
with self.assertRaisesRegex(ValueError, error_regex):
429-
nested_dict = self.convert_nested_dict_to_list(raw_keys)
430-
hp = self.initialize_config(nested_dict)
431-
hp.create_mesh("model_config")
438+
mock_device_count_fn.return_value = mock_num_devices
439+
with self.assertRaisesRegex(ValueError, error_regex):
440+
nested_dict = self.convert_nested_dict_to_list(raw_keys)
441+
hp = self.initialize_config(nested_dict)
442+
axis_shapes, axis_names = hp._parse_mesh_config("model_config")
443+
mesh_lib.create_mesh(axis_shapes, axis_names)
432444

433445
@parameterized.named_parameters(
434446
dict(

tests/cli/grpo_main_test.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,107 @@ def __init__(self, devices, axis_names, axis_types=None):
753753
role_to_mesh[rl_cluster_lib.Role.ACTOR],
754754
)
755755

756+
def test_split_mesh_delegates_device_allocation_to_mesh_utils(self):
757+
extra = """
758+
training_mode: "agentic_grpo"
759+
data_module: "tunix.cli.recipes.deepscaler_data"
760+
apply_chat_template_to_dataset: false
761+
data_config:
762+
train_data_path: "gs://fake/train.json"
763+
eval_data_path: "gs://fake/eval.parquet"
764+
prompt_key: "prompts"
765+
reward_functions: []
766+
verl_compatible: false
767+
chat_parser_config:
768+
type: "default"
769+
agent_class_path: null
770+
agent_kwargs: {}
771+
env_class_path: null
772+
env_kwargs: {}
773+
kubernetes_config: null
774+
agentic_grpo_config:
775+
num_generations: 2
776+
num_iterations: 1
777+
beta: 0.0
778+
epsilon: 0.2
779+
epsilon_high: 0.28
780+
system_prompt: ""
781+
max_concurrency: 1
782+
off_policy_steps: 0
783+
max_turns: 1
784+
context_ratio: 1
785+
sglang_jax_config:
786+
mem_fraction_static: 0.8
787+
vllm_config:
788+
hbm_utilization: 0.4
789+
"""
790+
pipeline = _make_pipeline(extra)
791+
actor_model_config = pipeline.config["actor_model_config"]
792+
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
793+
actor_model_config["mesh"] = {
794+
"shape": "(1,2)",
795+
"axis_names": "('fsdp','tp')",
796+
}
797+
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
798+
rollout_model_config = pipeline.config["rollout_model_config"]
799+
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
800+
rollout_model_config["mesh"] = {
801+
"shape": "(1,2)",
802+
"axis_names": "('fsdp','tp')",
803+
}
804+
805+
fake_devices = ["a0", "a1", "r0", "r1"]
806+
allocated_devices = {
807+
"actor_model_config": ["a0", "a1"],
808+
"rollout_model_config": ["r0", "r1"],
809+
}
810+
created_mesh_devices = {}
811+
812+
def fake_create_mesh(axis_shapes, axis_names, devices=None):
813+
model_key = (
814+
"actor_model_config"
815+
if axis_shapes == (1, 2) and axis_names == ("fsdp", "tp")
816+
and "actor_model_config" not in created_mesh_devices
817+
else "rollout_model_config"
818+
)
819+
created_mesh_devices[model_key] = list(devices)
820+
return (model_key, tuple(devices))
821+
822+
with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
823+
with mock.patch.object(
824+
grpo_main.mesh_lib,
825+
"allocate_named_mesh_device_slices",
826+
return_value=allocated_devices,
827+
) as allocate_mock:
828+
with mock.patch.object(
829+
grpo_main.mesh_lib,
830+
"create_mesh",
831+
side_effect=fake_create_mesh,
832+
):
833+
role_to_mesh = pipeline.create_role_to_mesh()
834+
835+
allocate_mock.assert_called_once_with(
836+
[
837+
("actor_model_config", 2),
838+
("rollout_model_config", 2),
839+
],
840+
devices=fake_devices,
841+
)
842+
self.assertEqual(created_mesh_devices["actor_model_config"], ["a0", "a1"])
843+
self.assertEqual(created_mesh_devices["rollout_model_config"], ["r0", "r1"])
844+
self.assertEqual(
845+
role_to_mesh[rl_cluster_lib.Role.ACTOR],
846+
("actor_model_config", ("a0", "a1")),
847+
)
848+
self.assertIs(
849+
role_to_mesh[rl_cluster_lib.Role.REFERENCE],
850+
role_to_mesh[rl_cluster_lib.Role.ACTOR],
851+
)
852+
self.assertEqual(
853+
role_to_mesh[rl_cluster_lib.Role.ROLLOUT],
854+
("rollout_model_config", ("r0", "r1")),
855+
)
856+
756857

757858
if __name__ == "__main__":
758859
absltest.main()

0 commit comments

Comments
 (0)