Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/cpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ jobs:
run: |
python -m pytest tests/cli/utils/ -v --tb=short

- name: Run shared mesh and topology tests
run: |
python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short

- name: Run perf tests
run: |
python -m pytest tests/perf/ -v --tb=short
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ jobs:
- name: Run tunix tests not covered by the above categories
run: |
# This category catches tests that have been added but are not yet covered by CI. Whenever you add a new folder under tests/, please add a dedicated category above and add an --ignore for it here.
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$?
if [ "${code:-0}" = "5" ]; then
echo "No tests collected (expected)."
exit 0
Expand Down
9 changes: 7 additions & 2 deletions docs/launching.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ This section provides a detailed explanation of the configuration parameters ava

#### Model Configuration (`model_config`)

These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration.
These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration.

* **`model_name`**: The unique full name identifier of the model. This
corresponds to the full name and should match exactly with the model name
Expand Down Expand Up @@ -287,6 +287,11 @@ These parameters define the base model, where to download it from, and how to sh
* **`mesh`**: Defines the hardware mesh layout for distributed training.
* `shape`: Tuple string defining mesh dimensions (e.g., `"(2,2)"` for a 2x2 grid).
* `axis_names`: Names for mesh axes, often used for parallelism strategies (e.g., `"('fsdp','tp')"` for Fully Sharded Data Parallelism and Tensor Parallelism).
* `allocation_policy`: Optional policy controlling how Tunix carves this mesh from a larger device pool.
    * `COMPACT`: Prefer the smallest remaining coordinate region that still fits the requested mesh shape.
    * `PERFORMANCE`: Prefer extracting regions whose dimensions are as close to cubical (balanced) as possible, among the supported shapes.
* If omitted, Tunix defaults to `COMPACT`.
  * When multiple owned meshes are allocated together, they must all specify the same `allocation_policy`, or all leave it unset so the default (`COMPACT`) applies.


#### Tokenizer Configuration (`tokenizer_config`)
Expand Down Expand Up @@ -338,7 +343,7 @@ General settings for the training loop, logging, and checkpointing.

* **`eval_every_n_steps`**: Frequency of running evaluation steps.

* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients
* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients
before performing a parameter update (simulates larger batch sizes).

* **`checkpointing_options`**:
Expand Down
123 changes: 89 additions & 34 deletions tests/cli/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tunix.sft import peft_trainer
from tunix.tests import test_common as tc
from tunix.utils import env_utils
from tunix.utils import mesh as mesh_lib


class ConfigTest(parameterized.TestCase):
Expand Down Expand Up @@ -262,7 +263,7 @@ def test_learning_rate_schedule_valid(self, overrides):
self.assertIsNotNone(lr_schedule)
self.assertTrue(callable(lr_schedule), "lr_schedule should be callable")

# --- Tests for create_mesh ---
# --- Tests for mesh config parsing and mesh creation ---
@parameterized.named_parameters(
dict(
testcase_name="valid_1d",
Expand Down Expand Up @@ -311,40 +312,93 @@ def test_create_mesh_valid(
):
mock_device_count_fn.return_value = mock_num_devices
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
mesh = hp.create_mesh("model_config")
self.assertEqual(
mesh,
jax.make_mesh(
expected[0],
expected[1],
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
),
)
axis_shapes, axis_names = hp.parse_mesh_config("model_config")
expected_mesh = object()

def test_create_mesh_with_assigned_devices(self):
raw_keys = {
"model_config": {
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
assigned_devices = ["d0", "d1", "d2", "d3"]
with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock:
mesh = mesh_lib.create_mesh(axis_shapes, axis_names)

class FakeMesh:
make_mesh_mock.assert_called_once_with(
expected[0],
expected[1],
axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]),
)
self.assertIs(mesh, expected_mesh)

def test_create_mesh_with_assigned_devices(self):
raw_keys = {
"model_config": {
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))
axis_shapes, axis_names = hp.parse_mesh_config("model_config")
assigned_devices = ["d0", "d1", "d2", "d3"]

def __init__(self, devices, axis_names, axis_types=None):
self.devices = devices
self.axis_names = axis_names
self.axis_types = axis_types
class FakeMesh:

with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
mesh = hp.create_mesh("model_config", devices=assigned_devices)
def __init__(self, devices, axis_names, axis_types=None):
self.devices = devices
self.axis_names = axis_names
self.axis_types = axis_types

self.assertEqual(mesh.devices.shape, (2, 2))
self.assertSequenceEqual(
mesh.devices.flatten().tolist(), assigned_devices
with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh):
mesh = mesh_lib.create_mesh(
axis_shapes,
axis_names,
devices=assigned_devices,
)
self.assertEqual(mesh.axis_names, ("x", "y"))

self.assertEqual(mesh.devices.shape, (2, 2))
self.assertSequenceEqual(
mesh.devices.flatten().tolist(), assigned_devices
)
self.assertEqual(mesh.axis_names, ("x", "y"))

def test_parse_mesh_allocation_policy_defaults_to_compact(self):
raw_keys = {
"model_config": {
"mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))

self.assertEqual(
hp._parse_mesh_allocation_policy("model_config"),
mesh_lib.normalize_allocation_policy(None),
)

def test_parse_mesh_allocation_policy_validates_explicit_value(self):
raw_keys = {
"model_config": {
"mesh": {
"shape": "(2, 2)",
"axis_names": "('x', 'y')",
"allocation_policy": "performance",
}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))

self.assertEqual(
hp._parse_mesh_allocation_policy("model_config"),
"PERFORMANCE",
)

def test_parse_mesh_allocation_policy_rejects_invalid_value(self):
raw_keys = {
"model_config": {
"mesh": {
"shape": "(2, 2)",
"axis_names": "('x', 'y')",
"allocation_policy": "fastest",
}
}
}
hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys))

with self.assertRaisesRegex(ValueError, "allocation_policy must be one of"):
hp._parse_mesh_allocation_policy("model_config")

@parameterized.named_parameters(
dict(
Expand Down Expand Up @@ -424,11 +478,12 @@ def test_create_mesh_invalid(
mock_num_devices,
error_regex,
):
mock_device_count_fn.return_value = mock_num_devices
with self.assertRaisesRegex(ValueError, error_regex):
nested_dict = self.convert_nested_dict_to_list(raw_keys)
hp = self.initialize_config(nested_dict)
hp.create_mesh("model_config")
mock_device_count_fn.return_value = mock_num_devices
with self.assertRaisesRegex(ValueError, error_regex):
nested_dict = self.convert_nested_dict_to_list(raw_keys)
hp = self.initialize_config(nested_dict)
axis_shapes, axis_names = hp.parse_mesh_config("model_config")
mesh_lib.create_mesh(axis_shapes, axis_names)

@parameterized.named_parameters(
dict(
Expand Down
118 changes: 93 additions & 25 deletions tests/cli/grpo_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import os
import pathlib
import tempfile
from typing import Any
from typing import cast
from unittest import mock

from absl.testing import absltest
Expand Down Expand Up @@ -573,26 +571,6 @@ def test_standard_grpo_kv_cache(self):
self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256)


class ComputeParamsTest(absltest.TestCase):

def test_compute_params_persists_dynamic_num_batches(self):
pipeline = _make_pipeline("")
pipeline.config["batch_size"] = 8
pipeline.config["num_batches"] = 0
pipeline.config["num_train_epochs"] = 1
pipeline.config["train_fraction"] = 0.8
rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"])
rl_training_config["max_steps"] = 0

raw_dataset = mock.Mock()
raw_dataset.__len__ = mock.Mock(return_value=7473)

pipeline.compute_params(raw_dataset)

self.assertEqual(pipeline.config["num_batches"], 934)
self.assertEqual(rl_training_config["max_steps"], 747)


# ---------------------------------------------------------------------------
# GRPOConfig construction
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -710,7 +688,21 @@ def test_split_mesh_uses_explicit_role_meshes(self):
"axis_names": "('fsdp','tp')",
}

fake_devices = list(range(4))
class FakeDevice:

def __init__(self, device_id, coords):
self.id = device_id
self.coords = coords
self.process_index = 0
self.slice_index = 0
self.device_kind = "TPU v5e"

fake_devices = [
FakeDevice(0, (0, 0)),
FakeDevice(1, (1, 0)),
FakeDevice(2, (0, 1)),
FakeDevice(3, (1, 1)),
]

class FakeMesh:

Expand All @@ -726,11 +718,21 @@ def __init__(self, devices, axis_names, axis_types=None):
role_to_mesh = pipeline.create_role_to_mesh()

self.assertSequenceEqual(
role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(),
[
device.id
for device in role_to_mesh[rl_cluster_lib.Role.ACTOR]
.devices.flatten()
.tolist()
],
[0, 1],
)
self.assertSequenceEqual(
role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(),
[
device.id
for device in role_to_mesh[rl_cluster_lib.Role.ROLLOUT]
.devices.flatten()
.tolist()
],
[2, 3],
)
self.assertEqual(
Expand All @@ -746,6 +748,72 @@ def __init__(self, devices, axis_names, axis_types=None):
role_to_mesh[rl_cluster_lib.Role.ACTOR],
)

def test_create_role_to_mesh_passes_configured_allocation_policy(self):
extra = """
training_mode: "agentic_grpo"
verl_compatible: false
chat_parser_config:
type: "default"
agent_class_path: null
agent_kwargs: {}
env_class_path: null
env_kwargs: {}
kubernetes_config: null
agentic_grpo_config:
num_generations: 2
num_iterations: 1
beta: 0.0
epsilon: 0.2
epsilon_high: 0.28
system_prompt: ""
max_concurrency: 1
off_policy_steps: 0
max_turns: 1
sglang_jax_config:
mem_fraction_static: 0.8
vllm_config:
hbm_utilization: 0.4
"""
pipeline = _make_pipeline(extra)
actor_model_config = pipeline.config["actor_model_config"]
if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig):
actor_model_config["mesh"] = {
"shape": "(2,1)",
"axis_names": "('fsdp','tp')",
"allocation_policy": "PERFORMANCE",
}
pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"}
rollout_model_config = pipeline.config["rollout_model_config"]
if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig):
rollout_model_config["mesh"] = {
"shape": "(1,2)",
"axis_names": "('fsdp','tp')",
"allocation_policy": "PERFORMANCE",
}

fake_devices = list(range(4))

with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices):
with mock.patch.object(
grpo_main.mesh_lib,
"allocate_named_mesh_device_slices",
return_value={
"actor_model_config": [0, 1],
"rollout_model_config": [2, 3],
},
) as allocate_mock, mock.patch.object(
grpo_main.mesh_lib,
"create_mesh",
side_effect=[object(), object()],
):
pipeline.create_role_to_mesh()

allocate_mock.assert_called_once_with(
[("actor_model_config", 2), ("rollout_model_config", 2)],
devices=fake_devices,
allocation_policy="PERFORMANCE",
)


if __name__ == "__main__":
absltest.main()
Loading
Loading