diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index ce1fc997d..ac7da6a94 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -63,6 +63,10 @@ jobs: run: | python -m pytest tests/cli/utils/ -v --tb=short + - name: Run shared mesh and topology tests + run: | + python -m pytest tests/utils/mesh_utils_test.py tests/utils/topology_test.py -v --tb=short + - name: Run perf tests run: | python -m pytest tests/perf/ -v --tb=short diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml index 73694881a..f1fa7561d 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml @@ -120,7 +120,7 @@ jobs: - name: Run tunix tests not covered by the above categories run: | # This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here. - python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$? + python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/utils/mesh_utils_test.py --ignore=tests/utils/topology_test.py --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ --ignore=tests/smoke_tests/ || code=$? if [ "${code:-0}" = "5" ]; then echo "No tests collected (expected)." 
exit 0 diff --git a/docs/launching.md b/docs/launching.md index f309fbef9..a2abdbd3d 100644 --- a/docs/launching.md +++ b/docs/launching.md @@ -254,7 +254,7 @@ This section provides a detailed explanation of the configuration parameters ava #### Model Configuration (`model_config`) -These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration. +These parameters define the base model, where to download it from, and how to shard it across TPUs/GPUs. Note that `actor_model_config`, `reference_model_config`, and `rollout_model_config` typically inherit from this base configuration. * **`model_name`**: The unique full name identifier of the model. This corresponds to the full name and should match exactly with the model name @@ -287,6 +287,11 @@ These parameters define the base model, where to download it from, and how to sh * **`mesh`**: Defines the hardware mesh layout for distributed training. * `shape`: Tuple string defining mesh dimensions (e.g., `"(2,2)"` for a 2x2 grid). * `axis_names`: Names for mesh axes, often used for parallelism strategies (e.g., `"('fsdp','tp')"` for Fully Sharded Data Parallelism and Tensor Parallelism). + * `allocation_policy`: Optional policy controlling how Tunix carves this mesh from a larger device pool. + * `COMPACT`: Prefer the smallest fitting remaining coord region. + * `PERFORMANCE`: Prefer more cubical shapes among the supported extracted shapes. + * If omitted, Tunix defaults to `COMPACT`. + * When multiple owned meshes are allocated together, they must all use the same `allocation_policy`, or all leave it unset to use the default. #### Tokenizer Configuration (`tokenizer_config`) @@ -338,7 +343,7 @@ General settings for the training loop, logging, and checkpointing. * **`eval_every_n_steps`**: Frequency of running evaluation steps. 
-* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients +* **`gradient_accumulation_steps`**: Number of steps to accumulate gradients before performing a parameter update (simulates larger batch sizes). * **`checkpointing_options`**: diff --git a/tests/cli/config_test.py b/tests/cli/config_test.py index 06e24f656..35e4c86aa 100644 --- a/tests/cli/config_test.py +++ b/tests/cli/config_test.py @@ -27,6 +27,7 @@ from tunix.sft import peft_trainer from tunix.tests import test_common as tc from tunix.utils import env_utils +from tunix.utils import mesh as mesh_lib class ConfigTest(parameterized.TestCase): @@ -262,7 +263,7 @@ def test_learning_rate_schedule_valid(self, overrides): self.assertIsNotNone(lr_schedule) self.assertTrue(callable(lr_schedule), "lr_schedule should be callable") - # --- Tests for create_mesh --- + # --- Tests for mesh config parsing and mesh creation --- @parameterized.named_parameters( dict( testcase_name="valid_1d", @@ -311,40 +312,93 @@ def test_create_mesh_valid( ): mock_device_count_fn.return_value = mock_num_devices hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) - mesh = hp.create_mesh("model_config") - self.assertEqual( - mesh, - jax.make_mesh( - expected[0], - expected[1], - axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), - ), - ) + axis_shapes, axis_names = hp.parse_mesh_config("model_config") + expected_mesh = object() - def test_create_mesh_with_assigned_devices(self): - raw_keys = { - "model_config": { - "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} - } - } - hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) - assigned_devices = ["d0", "d1", "d2", "d3"] + with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock: + mesh = mesh_lib.create_mesh(axis_shapes, axis_names) - class FakeMesh: + make_mesh_mock.assert_called_once_with( + expected[0], + expected[1], + axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), 
+ ) + self.assertIs(mesh, expected_mesh) + + def test_create_mesh_with_assigned_devices(self): + raw_keys = { + "model_config": { + "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} + } + } + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) + axis_shapes, axis_names = hp.parse_mesh_config("model_config") + assigned_devices = ["d0", "d1", "d2", "d3"] - def __init__(self, devices, axis_names, axis_types=None): - self.devices = devices - self.axis_names = axis_names - self.axis_types = axis_types + class FakeMesh: - with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): - mesh = hp.create_mesh("model_config", devices=assigned_devices) + def __init__(self, devices, axis_names, axis_types=None): + self.devices = devices + self.axis_names = axis_names + self.axis_types = axis_types - self.assertEqual(mesh.devices.shape, (2, 2)) - self.assertSequenceEqual( - mesh.devices.flatten().tolist(), assigned_devices + with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): + mesh = mesh_lib.create_mesh( + axis_shapes, + axis_names, + devices=assigned_devices, ) - self.assertEqual(mesh.axis_names, ("x", "y")) + + self.assertEqual(mesh.devices.shape, (2, 2)) + self.assertSequenceEqual( + mesh.devices.flatten().tolist(), assigned_devices + ) + self.assertEqual(mesh.axis_names, ("x", "y")) + + def test_parse_mesh_allocation_policy_defaults_to_compact(self): + raw_keys = { + "model_config": { + "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} + } + } + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) + + self.assertEqual( + hp._parse_mesh_allocation_policy("model_config"), + mesh_lib.normalize_allocation_policy(None), + ) + + def test_parse_mesh_allocation_policy_validates_explicit_value(self): + raw_keys = { + "model_config": { + "mesh": { + "shape": "(2, 2)", + "axis_names": "('x', 'y')", + "allocation_policy": "performance", + } + } + } + hp = 
self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) + + self.assertEqual( + hp._parse_mesh_allocation_policy("model_config"), + "PERFORMANCE", + ) + + def test_parse_mesh_allocation_policy_rejects_invalid_value(self): + raw_keys = { + "model_config": { + "mesh": { + "shape": "(2, 2)", + "axis_names": "('x', 'y')", + "allocation_policy": "fastest", + } + } + } + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) + + with self.assertRaisesRegex(ValueError, "allocation_policy must be one of"): + hp._parse_mesh_allocation_policy("model_config") @parameterized.named_parameters( dict( @@ -424,11 +478,12 @@ def test_create_mesh_invalid( mock_num_devices, error_regex, ): - mock_device_count_fn.return_value = mock_num_devices - with self.assertRaisesRegex(ValueError, error_regex): - nested_dict = self.convert_nested_dict_to_list(raw_keys) - hp = self.initialize_config(nested_dict) - hp.create_mesh("model_config") + mock_device_count_fn.return_value = mock_num_devices + with self.assertRaisesRegex(ValueError, error_regex): + nested_dict = self.convert_nested_dict_to_list(raw_keys) + hp = self.initialize_config(nested_dict) + axis_shapes, axis_names = hp.parse_mesh_config("model_config") + mesh_lib.create_mesh(axis_shapes, axis_names) @parameterized.named_parameters( dict( diff --git a/tests/cli/grpo_main_test.py b/tests/cli/grpo_main_test.py index 283b7742e..495caa8ed 100644 --- a/tests/cli/grpo_main_test.py +++ b/tests/cli/grpo_main_test.py @@ -20,8 +20,6 @@ import os import pathlib import tempfile -from typing import Any -from typing import cast from unittest import mock from absl.testing import absltest @@ -573,26 +571,6 @@ def test_standard_grpo_kv_cache(self): self.assertEqual(cfg.kv_cache_size, 256 + 512 + 256) -class ComputeParamsTest(absltest.TestCase): - - def test_compute_params_persists_dynamic_num_batches(self): - pipeline = _make_pipeline("") - pipeline.config["batch_size"] = 8 - pipeline.config["num_batches"] = 0 - 
pipeline.config["num_train_epochs"] = 1 - pipeline.config["train_fraction"] = 0.8 - rl_training_config = cast(dict[str, Any], pipeline.config["rl_training_config"]) - rl_training_config["max_steps"] = 0 - - raw_dataset = mock.Mock() - raw_dataset.__len__ = mock.Mock(return_value=7473) - - pipeline.compute_params(raw_dataset) - - self.assertEqual(pipeline.config["num_batches"], 934) - self.assertEqual(rl_training_config["max_steps"], 747) - - # --------------------------------------------------------------------------- # GRPOConfig construction # --------------------------------------------------------------------------- @@ -710,7 +688,21 @@ def test_split_mesh_uses_explicit_role_meshes(self): "axis_names": "('fsdp','tp')", } - fake_devices = list(range(4)) + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.process_index = 0 + self.slice_index = 0 + self.device_kind = "TPU v5e" + + fake_devices = [ + FakeDevice(0, (0, 0)), + FakeDevice(1, (1, 0)), + FakeDevice(2, (0, 1)), + FakeDevice(3, (1, 1)), + ] class FakeMesh: @@ -726,11 +718,21 @@ def __init__(self, devices, axis_names, axis_types=None): role_to_mesh = pipeline.create_role_to_mesh() self.assertSequenceEqual( - role_to_mesh[rl_cluster_lib.Role.ACTOR].devices.flatten().tolist(), + [ + device.id + for device in role_to_mesh[rl_cluster_lib.Role.ACTOR] + .devices.flatten() + .tolist() + ], [0, 1], ) self.assertSequenceEqual( - role_to_mesh[rl_cluster_lib.Role.ROLLOUT].devices.flatten().tolist(), + [ + device.id + for device in role_to_mesh[rl_cluster_lib.Role.ROLLOUT] + .devices.flatten() + .tolist() + ], [2, 3], ) self.assertEqual( @@ -746,6 +748,72 @@ def __init__(self, devices, axis_names, axis_types=None): role_to_mesh[rl_cluster_lib.Role.ACTOR], ) + def test_create_role_to_mesh_passes_configured_allocation_policy(self): + extra = """ +training_mode: "agentic_grpo" +verl_compatible: false +chat_parser_config: + type: "default" +agent_class_path: 
null +agent_kwargs: {} +env_class_path: null +env_kwargs: {} +kubernetes_config: null +agentic_grpo_config: + num_generations: 2 + num_iterations: 1 + beta: 0.0 + epsilon: 0.2 + epsilon_high: 0.28 + system_prompt: "" + max_concurrency: 1 + off_policy_steps: 0 + max_turns: 1 +sglang_jax_config: + mem_fraction_static: 0.8 +vllm_config: + hbm_utilization: 0.4 +""" + pipeline = _make_pipeline(extra) + actor_model_config = pipeline.config["actor_model_config"] + if isinstance(actor_model_config, omegaconf.dictconfig.DictConfig): + actor_model_config["mesh"] = { + "shape": "(2,1)", + "axis_names": "('fsdp','tp')", + "allocation_policy": "PERFORMANCE", + } + pipeline.config["reference_model_config"] = {"same_mesh_as": "actor"} + rollout_model_config = pipeline.config["rollout_model_config"] + if isinstance(rollout_model_config, omegaconf.dictconfig.DictConfig): + rollout_model_config["mesh"] = { + "shape": "(1,2)", + "axis_names": "('fsdp','tp')", + "allocation_policy": "PERFORMANCE", + } + + fake_devices = list(range(4)) + + with mock.patch.object(grpo_main.jax, "devices", return_value=fake_devices): + with mock.patch.object( + grpo_main.mesh_lib, + "allocate_named_mesh_device_slices", + return_value={ + "actor_model_config": [0, 1], + "rollout_model_config": [2, 3], + }, + ) as allocate_mock, mock.patch.object( + grpo_main.mesh_lib, + "create_mesh", + side_effect=[object(), object()], + ): + pipeline.create_role_to_mesh() + + allocate_mock.assert_called_once_with( + [("actor_model_config", 2), ("rollout_model_config", 2)], + devices=fake_devices, + allocation_policy="PERFORMANCE", + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/utils/mesh_utils_test.py b/tests/utils/mesh_utils_test.py new file mode 100644 index 000000000..596a5f693 --- /dev/null +++ b/tests/utils/mesh_utils_test.py @@ -0,0 +1,1278 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +import jax +from tunix.utils import mesh + + +class MeshUtilsTest(absltest.TestCase): + + def test_device_host_key_prefers_slice_and_process_metadata(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 7 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 7)) + + def test_device_host_key_uses_task_id_without_slice_index(self): + class FakeDevice: + + def __init__(self): + self.task_id = 9 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (None, 9)) + + def test_device_host_key_prefers_task_id_over_process_index(self): + class FakeDevice: + + def __init__(self): + self.slice_index = 4 + self.process_index = 0 + self.task_id = 9 + + self.assertEqual(mesh.device_host_key(FakeDevice()), (4, 9)) + + def test_device_host_key_returns_none_without_task_metadata(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.device_host_key(FakeDevice())) + + def test_device_slice_id_uses_slice_index_only(self): + class SliceIndexDevice: + + def __init__(self): + self.slice_index = 4 + + self.assertEqual(mesh.device_slice_id(SliceIndexDevice()), 4) + self.assertIsNone(mesh.device_slice_id(object())) + + def test_group_devices_by_slice_preserves_first_seen_order(self): + class FakeDevice: + + def __init__(self, device_id, slice_index): + self.id = device_id + self.slice_index = slice_index + + grouped = mesh.group_devices_by_slice([ + FakeDevice(0, 2), + FakeDevice(1, 2), + FakeDevice(2, 1), + FakeDevice(3, 1), + ]) + + 
self.assertEqual([[device.id for device in group] for group in grouped], [[0, 1], [2, 3]]) + + def test_group_devices_by_slice_treats_missing_metadata_as_one_slice(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + + grouped = mesh.group_devices_by_slice([ + FakeDevice(0), + FakeDevice(1), + ]) + + self.assertEqual([[device.id for device in group] for group in grouped], [[0, 1]]) + + def test_candidate_uses_whole_chips_requires_all_cores(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + self.coords = coords + self.core_on_chip = core_on_chip + + topology = mesh.get_coord_topology([ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (0, 0, 0), 1), + FakeDevice(2, (1, 0, 0), 0), + FakeDevice(3, (1, 0, 0), 1), + ]) + + self.assertFalse( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0, 0), (1, 0, 0, 0)], + ) + ) + self.assertTrue( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0, 0), (0, 0, 0, 1), (1, 0, 0, 0), (1, 0, 0, 1)], + ) + ) + + def test_candidate_uses_whole_chips_ignores_plain_coords_without_core_axis(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + topology = mesh.get_coord_topology([ + FakeDevice(0, (0, 0, 0)), + FakeDevice(1, (0, 0, 1)), + ]) + + self.assertTrue( + mesh.candidate_uses_whole_chips( + topology, + [(0, 0, 0)], + ) + ) + + def test_satisfies_host_bound_shape_rejects_ragged_coords(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + host_devices = [ + FakeDevice(0, (1, 1, 0)), + FakeDevice(1, (1, 0, 1)), + FakeDevice(2, (0, 1, 1)), + FakeDevice(3, (1, 1, 1)), + ] + + self.assertFalse( + mesh._satisfies_host_bound_shape( + host_devices, + (2, 2, 1), + 4, + ) + ) + + def test_get_coord_topology_builds_bounding_box(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id 
+ self.coords = coords + + fake_devices = [ + FakeDevice(0, (2, 1, 0)), + FakeDevice(1, (3, 1, 0)), + FakeDevice(2, (2, 2, 0)), + FakeDevice(3, (3, 2, 0)), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertIsNotNone(topology) + self.assertEqual(topology.num_dims, 3) + self.assertEqual(topology.max_shape, (2, 2, 1)) + self.assertEqual(topology.all_coords, ((2, 1, 0), (3, 1, 0), (2, 2, 0), (3, 2, 0))) + + def test_get_coord_topology_rejects_duplicate_coords(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0))] + + self.assertIsNone(mesh.get_coord_topology(fake_devices)) + + def test_get_coord_topology_uses_core_on_chip_to_disambiguate_devices(self): + class FakeDevice: + + def __init__(self, coords, core_on_chip): + self.coords = coords + self.core_on_chip = core_on_chip + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((0, 0, 0), 1), + ] + + topology = mesh.get_coord_topology(fake_devices) + + self.assertIsNotNone(topology) + self.assertEqual(topology.all_coords, ((0, 0, 0, 0), (0, 0, 0, 1))) + self.assertTrue(topology.has_core_on_chip_dimension) + + def test_get_coord_topology_marks_plain_coords_without_core_axis(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + topology = mesh.get_coord_topology([FakeDevice((0, 0, 0)), FakeDevice((0, 0, 1))]) + + self.assertIsNotNone(topology) + self.assertFalse(topology.has_core_on_chip_dimension) + + def test_get_coord_topology_rejects_empty_device_list(self): + self.assertIsNone(mesh.get_coord_topology([])) + + def test_get_coord_topology_rejects_mismatched_coord_dimensions(self): + class FakeDevice: + + def __init__(self, coords): + self.coords = coords + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0, 1))] + + self.assertIsNone(mesh.get_coord_topology(fake_devices)) + + def test_summarize_devices_for_logging_includes_id_coords_and_host(self): + class 
FakeDevice: + + def __init__(self, device_id, coords, process_index, slice_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + self.slice_index = slice_index + + self.assertEqual( + mesh.summarize_devices_for_logging([FakeDevice(11, (1, 2, 0), 5, 6)]), + [{"id": 11, "coords": (1, 2, 0), "host": (6, 5)}], + ) + + def test_group_devices_by_host_groups_equal_sized_hosts(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + grouped = mesh.group_devices_by_host([ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + FakeDevice(3, 1), + ]) + + self.assertEqual([[device.id for device in group] for group in grouped], [[0, 1], [2, 3]]) + + def test_group_devices_by_host_sorts_by_slice_then_host_id(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, process_index): + self.id = device_id + self.slice_index = slice_index + self.process_index = process_index + + grouped = mesh.group_devices_by_host([ + FakeDevice(3, 1, 1), + FakeDevice(2, 1, 0), + FakeDevice(1, 0, 1), + FakeDevice(0, 0, 0), + ]) + + self.assertEqual([[device.id for device in group] for group in grouped], [[0], [1], [2], [3]]) + + def test_allocate_named_mesh_device_slices_ignores_host_groups_when_coord_allocation_fails(self): + class FakeDevice: + + def __init__(self, device_id, task_id, coords, core_on_chip): + self.id = device_id + self.process_index = 0 + self.task_id = task_id + self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for task_id, z in ((0, 0), (1, 1)): + for x in range(2): + for y in range(2): + for core_on_chip in (0, 1): + fake_devices.append( + FakeDevice(device_id, task_id, (x, y, z), core_on_chip) + ) + device_id += 1 + + with self.assertRaisesRegex( + ValueError, + "coord-based allocation could not construct a valid box", + ): + with mock.patch.object(mesh, 
"_allocate_devices_by_coords", return_value=(None, None)): + mesh.allocate_named_mesh_device_slices( + [("actor", 8)], + devices=fake_devices, + ) + + def test_group_devices_by_host_returns_none_without_host_metadata(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.group_devices_by_host([FakeDevice()])) + + def test_group_devices_by_host_returns_none_for_inconsistent_host_sizes(self): + class FakeDevice: + + def __init__(self, device_id, process_index): + self.id = device_id + self.process_index = process_index + + self.assertIsNone( + mesh.group_devices_by_host([ + FakeDevice(0, 0), + FakeDevice(1, 0), + FakeDevice(2, 1), + ]) + ) + + def test_divisors_returns_sorted_unique_factors(self): + self.assertEqual(mesh._divisors(12), [1, 2, 3, 4, 6, 12]) + + def test_enumerate_box_shapes_returns_shapes_with_requested_volume(self): + self.assertEqual( + mesh._enumerate_box_shapes(4, (4, 2, 2)), + [(1, 2, 2), (2, 1, 2), (2, 2, 1), (4, 1, 1)], + ) + + def test_coord_box_score_prefers_more_compact_shapes(self): + compact_score = mesh._coord_box_score((0, 0, 0), (2, 2, 1)) + stretched_score = mesh._coord_box_score((0, 0, 0), (1, 4, 1)) + + self.assertLess(compact_score, stretched_score) + + def test_split_coord_region_returns_z_y_x_remainders(self): + region = mesh.CoordRegion((0, 0, 0), (16, 16, 16)) + + self.assertEqual( + mesh._split_coord_region(region, (0, 0, 0), (4, 4, 8)), + ( + mesh.CoordRegion((0, 0, 8), (4, 4, 8)), + mesh.CoordRegion((0, 4, 0), (4, 12, 16)), + mesh.CoordRegion((4, 0, 0), (12, 16, 16)), + ), + ) + + def test_split_coord_region_accounts_for_non_origin_allocated_start(self): + region = mesh.CoordRegion((0, 0, 0), (16, 16, 16)) + + self.assertEqual( + mesh._split_coord_region(region, (4, 4, 4), (4, 4, 8)), + ( + mesh.CoordRegion((4, 4, 0), (4, 4, 4)), + mesh.CoordRegion((4, 4, 12), (4, 4, 4)), + mesh.CoordRegion((4, 0, 0), (4, 4, 16)), + mesh.CoordRegion((4, 8, 0), (4, 8, 16)), + mesh.CoordRegion((0, 0, 0), (4, 16, 16)), + 
mesh.CoordRegion((8, 0, 0), (8, 16, 16)), + ), + ) + + def test_device_mesh_coords_appends_core_on_chip_when_present(self): + class FakeDevice: + + def __init__(self): + self.coords = (1, 2, 0) + self.core_on_chip = 1 + + self.assertEqual( + mesh.device_mesh_coords(FakeDevice()), + (1, 2, 0, 1), + ) + + def test_device_mesh_coords_returns_none_without_coords(self): + class FakeDevice: + pass + + self.assertIsNone(mesh.device_mesh_coords(FakeDevice())) + + def test_known_host_mesh_shape_returns_none_for_unknown_device_family(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0, 0) + self.device_kind = "unknown" + + self.assertIsNone(mesh.known_host_mesh_shape([FakeDevice()])) + + def test_known_host_mesh_shape_returns_none_when_coord_rank_mismatches_bounds(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0) + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice() for _ in range(128)] + + self.assertIsNone(mesh.known_host_mesh_shape(fake_devices)) + + def test_resolve_per_host_mesh_shape_returns_inferred_shape(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((2, 0, 0), 1), + FakeDevice((3, 0, 0), 1), + FakeDevice((2, 1, 0), 1), + FakeDevice((3, 1, 0), 1), + ] + + self.assertEqual(mesh.resolve_per_host_mesh_shape(fake_devices), (2, 2, 1)) + + def test_known_host_mesh_shape_uses_static_topology_metadata(self): + class FakeDevice: + + def __init__(self): + self.coords = (0, 0, 0) + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice() for _ in range(128)] + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (2, 2, 1), + ) + + def test_known_host_mesh_shape_uses_single_host_bounds_for_tpu7x_2(self): + class FakeDevice: + + def __init__(self, 
coords): + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [FakeDevice((0, 0, 0)), FakeDevice((0, 0, 0))] + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (1, 1, 1), + ) + + def test_known_host_mesh_shape_appends_core_dimension_when_present(self): + class FakeDevice: + + def __init__(self, coords, core_on_chip): + self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + for x in range(4): + for y in range(4): + for z in range(4): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice((x, y, z), core_on_chip)) + + self.assertEqual( + mesh.known_host_mesh_shape(fake_devices), + (2, 2, 1, 2), + ) + + def test_known_host_mesh_shape_handles_edge_family_2d_coords(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + self.device_kind = "TPU v6e" + + fake_devices = [ + FakeDevice((0, 0), 0), + FakeDevice((1, 0), 0), + FakeDevice((0, 1), 0), + FakeDevice((1, 1), 0), + ] + + self.assertEqual(mesh.known_host_mesh_shape(fake_devices), (2, 2, 1)) + + def test_known_host_mesh_shape_handles_edge_family_3d_coords_with_core_axis(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.coords = coords + self.process_index = process_index + self.core_on_chip = 0 + self.device_kind = "TPU v5 lite" + + fake_devices = [] + for y in range(4): + for x in range(2): + fake_devices.append(FakeDevice((x, y, 0), 0)) + + self.assertEqual(mesh.known_host_mesh_shape(fake_devices), (2, 4, 1, 1)) + + def test_allocate_devices_by_coords_supports_edge_family_3d_coords_with_trailing_singleton(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + self.core_on_chip = 0 + self.device_kind = "TPU v5 lite" + + fake_devices = [] + device_id = 0 + for y in range(4): + for x in range(4): + 
fake_devices.append(FakeDevice(device_id, (x, y, 0), 0)) + device_id += 1 + + allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 8) + + self.assertEqual([device.id for device in allocated], [0, 1, 4, 5, 8, 9, 12, 13]) + + def test_resolve_per_host_mesh_shape_raises_on_mismatch(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (2, 0, 0), 0), + FakeDevice(3, (3, 0, 0), 0), + FakeDevice(4, (0, 0, 1), 1), + FakeDevice(5, (1, 0, 1), 1), + FakeDevice(6, (2, 0, 1), 1), + FakeDevice(7, (3, 0, 1), 1), + ] + + with self.assertRaisesRegex(ValueError, "Observed host devices do not match known host bounds"): + mesh.resolve_per_host_mesh_shape(fake_devices) + + def test_allocate_named_mesh_device_slices_prefers_coord_boxes(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + + fake_devices = [ + FakeDevice(0, (0, 0, 0, 0)), + FakeDevice(1, (0, 0, 0, 1)), + FakeDevice(2, (1, 0, 0, 0)), + FakeDevice(3, (1, 0, 0, 1)), + FakeDevice(4, (2, 0, 0, 0)), + FakeDevice(5, (2, 0, 0, 1)), + FakeDevice(6, (3, 0, 0, 0)), + FakeDevice(7, (3, 0, 0, 1)), + FakeDevice(8, (0, 1, 0, 0)), + FakeDevice(9, (0, 1, 0, 1)), + FakeDevice(10, (1, 1, 0, 0)), + FakeDevice(11, (1, 1, 0, 1)), + FakeDevice(12, (2, 1, 0, 0)), + FakeDevice(13, (2, 1, 0, 1)), + FakeDevice(14, (3, 1, 0, 0)), + FakeDevice(15, (3, 1, 0, 1)), + ] + + allocated = mesh.allocate_named_mesh_device_slices( + [("actor", 8)], + devices=fake_devices, + ) + + self.assertEqual( + [device.id for device in allocated["actor"]], + [0, 1, 8, 9, 2, 3, 10, 11], + ) + + def test_allocate_devices_by_coords_uses_core_on_chip_dimension(self): + class FakeDevice: + + def __init__(self, device_id, coords, core_on_chip): + self.id = device_id + 
self.coords = coords + self.core_on_chip = core_on_chip + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(4): + for y in range(4): + for z in range(2): + for core_on_chip in (0, 1): + fake_devices.append(FakeDevice(device_id, (x, y, z), core_on_chip)) + device_id += 1 + + allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 8) + + self.assertEqual( + [device.id for device in allocated], + [0, 1, 4, 5, 16, 17, 20, 21], + ) + + def test_allocate_devices_by_coords_returns_none_without_coord_topology(self): + class FakeDevice: + + def __init__(self, process_index): + self.process_index = process_index + + self.assertEqual( + mesh._allocate_devices_by_coords([FakeDevice(0), FakeDevice(0)], 2), + (None, None), + ) + + def test_allocate_devices_by_coords_returns_best_contiguous_box(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (0, 1, 0), 0), + FakeDevice(3, (1, 1, 0), 0), + FakeDevice(4, (2, 0, 0), 1), + FakeDevice(5, (3, 0, 0), 1), + FakeDevice(6, (2, 1, 0), 1), + FakeDevice(7, (3, 1, 0), 1), + ] + + allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 4) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) + + def test_allocate_devices_by_coords_prefers_more_cubical_supported_shape(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(4): + for y in range(8): + for z in range(16): + fake_devices.append(FakeDevice(device_id, (x, y, z))) + device_id += 1 + + allocated, _ = mesh._allocate_devices_by_coords(fake_devices, 256) + + allocated_coords = [device.coords for device in allocated] + mins = tuple(min(coords[dim] for coords in 
allocated_coords) for dim in range(3)) + maxs = tuple(max(coords[dim] for coords in allocated_coords) for dim in range(3)) + + self.assertLen(allocated, 256) + self.assertEqual(mins, (0, 0, 0)) + self.assertEqual(maxs, (3, 7, 7)) + + def test_allocate_devices_tracks_remaining_coord_regions(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(16): + for y in range(16): + for z in range(16): + fake_devices.append(FakeDevice(device_id, (x, y, z))) + device_id += 1 + + allocated, next_state = mesh.allocate_devices( + 128, + devices=fake_devices, + return_state=True, + ) + + allocated_coords = [device.coords for device in allocated] + + self.assertEqual( + next_state.remaining_coord_regions_by_slice, + { + None: ( + mesh.CoordRegion((0, 0, 8), (4, 4, 8)), + mesh.CoordRegion((0, 4, 0), (4, 12, 16)), + mesh.CoordRegion((4, 0, 0), (12, 16, 16)), + ) + }, + ) + self.assertEqual( + ( + tuple(min(coords[dim] for coords in allocated_coords) for dim in range(3)), + tuple(max(coords[dim] for coords in allocated_coords) for dim in range(3)), + ), + ((0, 0, 0), (3, 3, 7)), + ) + + def test_allocate_devices_prefers_smallest_remaining_coord_region_first(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(16): + for y in range(16): + for z in range(16): + fake_devices.append(FakeDevice(device_id, (x, y, z))) + device_id += 1 + + _, next_state = mesh.allocate_devices( + 128, + devices=fake_devices, + return_state=True, + ) + allocated = mesh.allocate_devices(128, allocation_state=next_state) + + allocated_coords = [device.coords for device in allocated] + self.assertEqual( + ( + tuple(min(coords[dim] for coords in allocated_coords) for dim in range(3)), + tuple(max(coords[dim] for coords in 
allocated_coords) for dim in range(3)), + ), + ((0, 0, 8), (3, 3, 15)), + ) + + def test_allocate_devices_matches_required_count_to_smallest_fitting_remaining_region(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(16): + for y in range(16): + for z in range(16): + fake_devices.append(FakeDevice(device_id, (x, y, z))) + device_id += 1 + + _, next_state = mesh.allocate_devices( + 128, + devices=fake_devices, + return_state=True, + ) + allocated = mesh.allocate_devices(576, allocation_state=next_state) + + allocated_coords = [device.coords for device in allocated] + self.assertEqual( + ( + tuple(min(coords[dim] for coords in allocated_coords) for dim in range(3)), + tuple(max(coords[dim] for coords in allocated_coords) for dim in range(3)), + ), + ((0, 4, 0), (3, 15, 11)), + ) + + def test_allocate_devices_performance_policy_prefers_more_cubical_shape(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(16): + for y in range(16): + for z in range(16): + fake_devices.append(FakeDevice(device_id, (x, y, z))) + device_id += 1 + + _, next_state = mesh.allocate_devices( + 128, + devices=fake_devices, + allocation_policy="PERFORMANCE", + return_state=True, + ) + allocated = mesh.allocate_devices(512, allocation_state=next_state) + + allocated_coords = [device.coords for device in allocated] + self.assertEqual( + ( + tuple(min(coords[dim] for coords in allocated_coords) for dim in range(3)), + tuple(max(coords[dim] for coords in allocated_coords) for dim in range(3)), + ), + ((4, 0, 0), (11, 7, 7)), + ) + + def test_allocate_devices_compact_policy_prefers_smallest_fitting_region(self): + class FakeDevice: + + def __init__(self, device_id, coords): + self.id = device_id + 
self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [] + device_id = 0 + for x in range(16): + for y in range(16): + for z in range(16): + fake_devices.append(FakeDevice(device_id, (x, y, z))) + device_id += 1 + + _, next_state = mesh.allocate_devices( + 128, + devices=fake_devices, + allocation_policy="COMPACT", + return_state=True, + ) + allocated = mesh.allocate_devices(512, allocation_state=next_state) + + allocated_coords = [device.coords for device in allocated] + self.assertEqual( + ( + tuple(min(coords[dim] for coords in allocated_coords) for dim in range(3)), + tuple(max(coords[dim] for coords in allocated_coords) for dim in range(3)), + ), + ((0, 4, 0), (3, 11, 15)), + ) + + def test_allocate_devices_rejects_mismatched_policy_for_existing_state(self): + fake_devices = [object(), object()] + + with mock.patch.object(mesh, "_allocate_devices_by_coords", return_value=([fake_devices[0]], None)): + _, state = mesh.allocate_devices( + 1, + devices=fake_devices, + allocation_policy="COMPACT", + return_state=True, + ) + + with self.assertRaisesRegex( + ValueError, + "allocation_policy must match allocation_state.allocation_policy", + ): + mesh.allocate_devices( + 1, + allocation_state=state, + allocation_policy="PERFORMANCE", + ) + + def test_allocate_devices_allocates_single_mesh(self): + fake_devices = [object(), object()] + + with mock.patch.object( + mesh, + "_allocate_devices_by_coords", + return_value=(fake_devices, None), + ) as allocate_mock: + allocated = mesh.allocate_devices(2, devices=fake_devices) + + allocate_mock.assert_called_once_with( + fake_devices, + 2, + None, + "COMPACT", + ) + self.assertIs(allocated, fake_devices) + + def test_allocate_devices_prefers_coords_over_host_groups(self): + class FakeDevice: + + def __init__(self, device_id, coords, process_index): + self.id = device_id + self.coords = coords + self.process_index = process_index + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice(0, (0, 0, 0), 0), + 
FakeDevice(1, (1, 0, 0), 0), + FakeDevice(2, (0, 1, 0), 0), + FakeDevice(3, (1, 1, 0), 0), + FakeDevice(4, (2, 0, 0), 1), + FakeDevice(5, (3, 0, 0), 1), + FakeDevice(6, (2, 1, 0), 1), + FakeDevice(7, (3, 1, 0), 1), + ] + + allocated = mesh.allocate_devices(4, devices=fake_devices) + + self.assertEqual([device.id for device in allocated], [0, 2, 1, 3]) + + def test_allocate_devices_returns_updated_state_for_incremental_use(self): + fake_devices = [object(), object(), object()] + + with mock.patch.object( + mesh, + "_allocate_devices_by_coords", + side_effect=[(fake_devices[:1], None), (fake_devices[1:], None)], + ): + assigned_devices, next_state = mesh.allocate_devices( + 1, + devices=fake_devices, + return_state=True, + ) + remaining_devices = list(next_state.remaining_devices) + assigned_devices_2, final_state = mesh.allocate_devices( + 2, + allocation_state=next_state, + return_state=True, + ) + + self.assertEqual(assigned_devices, fake_devices[:1]) + self.assertEqual(remaining_devices, fake_devices[1:]) + self.assertEqual(assigned_devices_2, fake_devices[1:]) + self.assertEqual(list(final_state.remaining_devices), []) + self.assertEqual(final_state.used_device_count, 3) + + def test_allocate_devices_raises_when_incremental_state_is_exhausted(self): + allocation_state = mesh.DeviceAllocationState( + remaining_devices=(), + full_devices_per_host=0, + host_bound_shape=None, + total_device_count=0, + ) + + with self.assertRaisesRegex(ValueError, "but only 0 remain available"): + mesh.allocate_devices(1, allocation_state=allocation_state) + + def test_allocate_devices_rejects_devices_and_state_together(self): + fake_devices = [object()] + allocation_state = mesh.DeviceAllocationState( + remaining_devices=tuple(fake_devices), + full_devices_per_host=0, + host_bound_shape=None, + total_device_count=1, + ) + + with self.assertRaisesRegex( + ValueError, + "Pass either devices or allocation_state to allocate_devices, not both", + ): + mesh.allocate_devices( + 1, + 
devices=fake_devices, + allocation_state=allocation_state, + ) + + def test_allocate_devices_prefers_single_slice_before_cross_slice(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, coords): + self.id = device_id + self.slice_index = slice_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 0, (2, 0, 0)), + FakeDevice(3, 0, (3, 0, 0)), + FakeDevice(4, 1, (4, 0, 0)), + FakeDevice(5, 1, (5, 0, 0)), + FakeDevice(6, 1, (6, 0, 0)), + FakeDevice(7, 1, (7, 0, 0)), + ] + + allocated = mesh.allocate_devices(4, devices=fake_devices) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3]) + + def test_allocate_devices_prefers_single_slice_before_other_slices(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, process_index, coords): + self.id = device_id + self.slice_index = slice_index + self.process_index = process_index + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice(0, 0, 0, (0, 0, 0)), + FakeDevice(1, 0, 0, (1, 0, 0)), + FakeDevice(2, 0, 0, (0, 1, 0)), + FakeDevice(3, 0, 0, (1, 1, 0)), + FakeDevice(4, 1, 1, (2, 0, 0)), + FakeDevice(5, 1, 1, (3, 0, 0)), + FakeDevice(6, 1, 1, (2, 1, 0)), + FakeDevice(7, 1, 1, (3, 1, 0)), + ] + + allocated = mesh.allocate_devices(2, devices=fake_devices) + + self.assertEqual([device.id for device in allocated], [0, 1]) + + def test_allocate_devices_raises_when_cross_slice_request_needs_partial_slice(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, coords): + self.id = device_id + self.slice_index = slice_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (2, 0, 0)), + FakeDevice(2, 0, (4, 0, 0)), + FakeDevice(3, 0, (6, 0, 0)), + FakeDevice(4, 1, (1, 0, 0)), + FakeDevice(5, 1, (3, 0, 0)), + FakeDevice(6, 1, (5, 0, 0)), + FakeDevice(7, 1, (7, 0, 0)), + ] + + with self.assertRaisesRegex( + 
ValueError, + "cross-slice allocation only supports whole slices", + ): + mesh.allocate_devices(6, devices=fake_devices) + + def test_allocate_devices_consumes_whole_slices_in_order_when_cross_slice(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, coords): + self.id = device_id + self.slice_index = slice_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (2, 0, 0)), + FakeDevice(2, 0, (4, 0, 0)), + FakeDevice(3, 0, (6, 0, 0)), + FakeDevice(4, 1, (1, 0, 0)), + FakeDevice(5, 1, (3, 0, 0)), + FakeDevice(6, 1, (5, 0, 0)), + FakeDevice(7, 1, (7, 0, 0)), + ] + + allocated = mesh.allocate_devices(8, devices=fake_devices) + + self.assertEqual([device.id for device in allocated], [0, 1, 2, 3, 4, 5, 6, 7]) + + def test_allocate_devices_skips_partial_slice_for_cross_slice_request(self): + class FakeDevice: + + def __init__(self, device_id, slice_index, coords): + self.id = device_id + self.slice_index = slice_index + self.coords = coords + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 0, (2, 0, 0)), + FakeDevice(3, 0, (3, 0, 0)), + FakeDevice(4, 1, (4, 0, 0)), + FakeDevice(5, 1, (5, 0, 0)), + FakeDevice(6, 1, (6, 0, 0)), + FakeDevice(7, 1, (7, 0, 0)), + FakeDevice(8, 2, (8, 0, 0)), + FakeDevice(9, 2, (9, 0, 0)), + FakeDevice(10, 2, (10, 0, 0)), + FakeDevice(11, 2, (11, 0, 0)), + ] + + assigned_devices, next_state = mesh.allocate_devices( + 2, + devices=fake_devices, + return_state=True, + ) + self.assertEqual([device.id for device in assigned_devices], [0, 1]) + + allocated = mesh.allocate_devices(8, allocation_state=next_state) + + self.assertEqual([device.id for device in allocated], [4, 5, 6, 7, 8, 9, 10, 11]) + + def test_allocate_named_mesh_device_slices_calls_allocate_devices_in_loop(self): + fake_devices = [object(), object(), object()] + state_0 = mesh.DeviceAllocationState( + remaining_devices=tuple(fake_devices), + full_devices_per_host=0, + 
host_bound_shape=None, + total_device_count=3, + used_device_count=0, + ) + state_1 = mesh.DeviceAllocationState( + remaining_devices=tuple(fake_devices[1:]), + full_devices_per_host=0, + host_bound_shape=None, + total_device_count=3, + used_device_count=1, + ) + state_2 = mesh.DeviceAllocationState( + remaining_devices=(), + full_devices_per_host=0, + host_bound_shape=None, + total_device_count=3, + used_device_count=3, + ) + + with mock.patch.object( + mesh, + "allocate_devices", + side_effect=[ + ([fake_devices[0]], state_1), + ([fake_devices[1], fake_devices[2]], state_2), + ], + ) as allocate_mock, mock.patch.object( + mesh, + "_create_device_allocation_state", + return_value=state_0, + ) as state_mock, mock.patch.object( + mesh.logging, + "warning", + ) as warning_mock: + allocated = mesh.allocate_named_mesh_device_slices( + [("mesh1", 1), ("mesh2", 2)], + devices=fake_devices, + ) + + state_mock.assert_called_once_with( + fake_devices, + allocation_policy="COMPACT", + ) + self.assertEqual(allocate_mock.call_count, 2) + self.assertEqual( + allocate_mock.call_args_list, + [ + mock.call( + 1, + mesh_name="mesh1", + allocation_state=state_0, + return_state=True, + ), + mock.call( + 2, + mesh_name="mesh2", + allocation_state=state_1, + return_state=True, + ), + ], + ) + warning_mock.assert_not_called() + self.assertEqual( + allocated, + {"mesh1": [fake_devices[0]], "mesh2": [fake_devices[1], fake_devices[2]]}, + ) + + @mock.patch.object(jax, "device_count") + def test_create_mesh_uses_jax_make_mesh_without_assigned_devices( + self, mock_device_count_fn + ): + mock_device_count_fn.return_value = 4 + expected_mesh = object() + + with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock: + created_mesh = mesh.create_mesh((2, 2), ("x", "y")) + + make_mesh_mock.assert_called_once_with( + (2, 2), + ("x", "y"), + axis_types=(jax.sharding.AxisType.Auto,) * 2, + ) + self.assertIs(created_mesh, expected_mesh) + + def 
test_create_mesh_uses_assigned_devices(self): + assigned_devices = ["d0", "d1", "d2", "d3"] + + class FakeMesh: + + def __init__(self, devices, axis_names, axis_types=None): + self.devices = devices + self.axis_names = axis_names + self.axis_types = axis_types + + with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): + created_mesh = mesh.create_mesh( + (2, 2), + ("x", "y"), + devices=assigned_devices, + ) + + self.assertEqual(created_mesh.devices.shape, (2, 2)) + self.assertEqual( + created_mesh.devices.flatten().tolist(), + assigned_devices, + ) + self.assertEqual(created_mesh.axis_names, ("x", "y")) + + def test_allocate_named_mesh_device_slices_uses_jax_devices_by_default(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + self.coords = (device_id, 0, 0) + + fake_devices = [FakeDevice(0), FakeDevice(1)] + + with mock.patch.object(mesh.jax, "devices", return_value=fake_devices): + allocated = mesh.allocate_named_mesh_device_slices([("trainer", 2)]) + + self.assertEqual([device.id for device in allocated["trainer"]], [0, 1]) + + def test_allocate_named_mesh_device_slices_raises_when_coord_allocation_fails(self): + class FakeDevice: + + def __init__(self, device_id, process_index, coords): + self.id = device_id + self.process_index = process_index + self.coords = coords + self.device_kind = "TPU v7" + + fake_devices = [ + FakeDevice(0, 0, (0, 0, 0)), + FakeDevice(1, 0, (1, 0, 0)), + FakeDevice(2, 0, (0, 1, 0)), + FakeDevice(3, 0, (1, 1, 0)), + FakeDevice(4, 1, (0, 0, 1)), + FakeDevice(5, 1, (1, 0, 1)), + FakeDevice(6, 1, (0, 1, 1)), + FakeDevice(7, 1, (1, 1, 1)), + ] + + with self.assertRaisesRegex( + ValueError, + "coord-based allocation could not construct a valid box", + ): + with mock.patch.object(mesh, "_allocate_devices_by_coords", return_value=(None, None)): + mesh.allocate_named_mesh_device_slices( + [("trainer", 4), ("rollout", 4)], + devices=fake_devices, + ) + + def 
test_allocate_named_mesh_device_slices_raises_when_not_enough_devices(self): + class FakeDevice: + + def __init__(self, device_id): + self.id = device_id + + fake_devices = [FakeDevice(0), FakeDevice(1)] + + with self.assertRaisesRegex(ValueError, "but only 2 remain available"): + mesh.allocate_named_mesh_device_slices( + [("trainer", 3)], + devices=fake_devices, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/utils/topology_test.py b/tests/utils/topology_test.py new file mode 100644 index 000000000..f28645f82 --- /dev/null +++ b/tests/utils/topology_test.py @@ -0,0 +1,162 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +from tunix.utils import topology + + +class TopologyTest(absltest.TestCase): + + def test_normalize_device_kind_recognizes_supported_families(self): + self.assertEqual(topology._normalize_device_kind("TPU v7"), "tpu7x") + self.assertEqual(topology._normalize_device_kind("TPU v6e"), "v6e") + self.assertEqual(topology._normalize_device_kind("TPU v6 lite"), "v6e") + self.assertEqual(topology._normalize_device_kind("TPU v5e"), "v5e") + self.assertEqual(topology._normalize_device_kind("TPU v5 lite"), "v5e") + self.assertEqual(topology._normalize_device_kind("TPU v5p"), "v5p") + self.assertEqual(topology._normalize_device_kind("TPU v4"), "v4") + self.assertIsNone(topology._normalize_device_kind("gpu")) + + def test_infer_chips_per_host_bounds_returns_none_for_empty_devices(self): + self.assertIsNone(topology.infer_chips_per_host_bounds([])) + + def test_infer_chips_per_host_bounds_returns_none_for_missing_device_kind(self): + class FakeDevice: + pass + + self.assertIsNone(topology.infer_chips_per_host_bounds([FakeDevice()])) + + def test_infer_chips_per_host_bounds_uses_single_host_shapes(self): + class FakeDevice: + + def __init__(self, device_kind, coords=None): + self.device_kind = device_kind + self.coords = coords + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v5e", (0, 0, 0))]), + (1, 1, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v6e", (0, 0))]), + (1, 1, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v6e", (0, 0, 0))]), + (1, 1, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v7"), FakeDevice("TPU v7")]), + (1, 1, 1), + ) + + def test_infer_chips_per_host_bounds_uses_multi_host_shape_otherwise(self): + class FakeDevice: + + def __init__(self, device_kind, coords=None): + self.device_kind = device_kind + self.coords = coords + + self.assertEqual( + 
topology.infer_chips_per_host_bounds([FakeDevice("TPU v7") for _ in range(4)]), + (2, 2, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v4") for _ in range(8)]), + (2, 2, 1), + ) + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice("TPU v5e", (0, 0)), FakeDevice("TPU v5e", (0, 1))]), + (2, 2, 1), + ) + + def test_infer_chips_per_host_bounds_prefers_runtime_host_shape(self): + class FakeDevice: + + def __init__(self, coords, process_index): + self.device_kind = "TPU v5 lite" + self.coords = coords + self.process_index = process_index + + fake_devices = [ + FakeDevice((0, 0, 0), 0), + FakeDevice((1, 0, 0), 0), + FakeDevice((0, 1, 0), 0), + FakeDevice((1, 1, 0), 0), + FakeDevice((0, 2, 0), 0), + FakeDevice((1, 2, 0), 0), + FakeDevice((0, 3, 0), 0), + FakeDevice((1, 3, 0), 0), + ] + + self.assertEqual( + topology.infer_chips_per_host_bounds(fake_devices), + (2, 4, 1), + ) + + def test_infer_chips_per_host_bounds_handles_callable_device_kind(self): + class FakeDevice: + + def device_kind(self): + return "TPU v7" + + self.assertEqual( + topology.infer_chips_per_host_bounds([FakeDevice() for _ in range(128)]), + (2, 2, 1), + ) + + def test_best_topology_shapes_for_chip_count_returns_unique_edge_shape(self): + self.assertEqual( + topology.best_topology_shapes_for_chip_count("TPU v6e", 8, chip_rank=2), + [(2, 4)], + ) + self.assertEqual( + topology.best_topology_shapes_for_chip_count("TPU v6e", 8, chip_rank=3), + [(2, 4, 1)], + ) + + def test_best_topology_shapes_for_chip_count_prefers_most_cubical_fish_shape(self): + self.assertEqual( + topology.best_topology_shapes_for_chip_count("TPU v7", 256), + [(4, 8, 8)], + ) + + def test_best_topology_shapes_for_chip_count_filters_by_available_shape(self): + self.assertEqual( + topology.best_topology_shapes_for_chip_count( + "TPU v6e", + 8, + chip_rank=2, + available_chip_shape=(1, 8), + ), + [], + ) + + def 
test_best_topology_shapes_for_chip_count_derives_shape_within_remaining_region(self): + self.assertEqual( + topology.best_topology_shapes_for_chip_count( + "TPU v7", + 576, + available_chip_shape=(4, 12, 16), + ), + [(4, 12, 12)], + ) + + def test_best_topology_shapes_for_chip_count_rejects_non_cube_multiple(self): + with self.assertRaisesRegex(ValueError, "must be divisible by 64 chips"): + topology.best_topology_shapes_for_chip_count("TPU v7", 96) + + +if __name__ == "__main__": + absltest.main() diff --git a/tunix/cli/base_agentic_config.yaml b/tunix/cli/base_agentic_config.yaml index 6027919de..94616daee 100644 --- a/tunix/cli/base_agentic_config.yaml +++ b/tunix/cli/base_agentic_config.yaml @@ -91,6 +91,12 @@ model_config: &base_model_config shape: "(2,2)" # "('fsdp',)" axis_names: "('fsdp','tp')" + # Optional device allocation policy used when this mesh is carved from a + # larger device pool. Supported values: + # COMPACT: pack into the smallest fitting remaining region. + # PERFORMANCE: prefer more cubical supported extracted shapes. + # When omitted, Tunix defaults to COMPACT. + # allocation_policy: "COMPACT" actor_model_config: <<: *base_model_config diff --git a/tunix/cli/base_config.yaml b/tunix/cli/base_config.yaml index 450ee9d85..648efe571 100644 --- a/tunix/cli/base_config.yaml +++ b/tunix/cli/base_config.yaml @@ -91,6 +91,12 @@ model_config: &base_model_config shape: "(2,2)" # "('fsdp',)" axis_names: "('fsdp','tp')" + # Optional device allocation policy used when this mesh is carved from a + # larger device pool. Supported values: + # COMPACT: pack into the smallest fitting remaining region. + # PERFORMANCE: prefer more cubical supported extracted shapes. + # When omitted, Tunix defaults to COMPACT. 
+ # allocation_policy: "COMPACT" actor_model_config: <<: *base_model_config diff --git a/tunix/cli/config.py b/tunix/cli/config.py index c4c0b417c..56ef849e4 100644 --- a/tunix/cli/config.py +++ b/tunix/cli/config.py @@ -36,6 +36,7 @@ from tunix.perf import metrics as perf_metrics from tunix.sft import metrics_logger from tunix.sft import profiler +from tunix.utils import mesh as mesh_lib # Define a prefix for environment variables that can override YAML keys _TUNIX_PREFIX = "T_" @@ -633,8 +634,21 @@ def create_optimizer( f"Check if the arguments match the signature of optax.{opt_type}: {e}" ) from e - def _parse_mesh_config(self, model_key: str) -> tuple[tuple[int, ...], tuple[str, ...]]: - """Validate and parse mesh configuration for a model key.""" + def parse_mesh_config(self, model_key: str) -> tuple[tuple[int, ...], tuple[str, ...]]: + """Validates and parses the mesh shape and axis names for one model. + + Args: + model_key: Config section name such as ``model_config`` or + ``actor_model_config``. + + Returns: + A tuple ``(axis_shapes, axis_names)`` ready to pass to + ``tunix.utils.mesh.create_mesh``. + + Raises: + ValueError: If the mesh config is missing, malformed, or internally + inconsistent. + """ mesh_config = self.config[model_key].get("mesh") if not mesh_config: @@ -693,49 +707,41 @@ def _parse_mesh_config(self, model_key: str) -> tuple[tuple[int, ...], tuple[str ) return tuple(axis_shapes), tuple(axis_names) - def create_mesh(self, model_key: str, devices: Sequence[Any] | None = None): - """Validate and extract mesh configuration from a dictionary. + def _parse_mesh_allocation_policy(self, model_key: str) -> str: + """Validates and returns the mesh allocation policy for one model. + + Mesh allocation policy controls how Tunix chooses device subsets when a + mesh must be carved out of a larger device pool. + + Supported values are: + + * ``COMPACT``: prefer the smallest remaining region that can satisfy the + request. 
+ * ``PERFORMANCE``: prefer the most cubical supported extracted shape. - Expects raw_keys to contain a 'mesh' key, which is a dictionary with 'shape' - and 'axis_names' keys. + When ``mesh.allocation_policy`` is omitted, this defaults to ``COMPACT``. Args: - model_key: A model key that contain raw mesh configuration. For example, - in rl, there are actor_model, critic_model and reference_model, each of - them could have different mesh configuration. - devices: Optional explicit device subset to use for the mesh. When - provided, the mesh shape must exactly match the number of assigned - devices. + model_key: Config section name such as ``model_config`` or + ``actor_model_config``. Returns: - A tuple containing (axis_shapes, axis_names), both as tuples. + The normalized allocation policy string. Raises: - ValueError: If the mesh configuration is missing, malformed, or invalid. + ValueError: If the mesh config is missing or the policy value is not + supported. """ - - axis_shapes, axis_names = self._parse_mesh_config(model_key) - num_devices = len(devices) if devices is not None else jax.device_count() - if np.prod(axis_shapes) > num_devices: + mesh_config = self.config[model_key].get("mesh") + if not mesh_config: + raise ValueError("Missing 'mesh' configuration in raw_keys.") + if not isinstance(mesh_config, collections.abc.Mapping): raise ValueError( - f"Mesh shape {axis_shapes} requires {np.prod(axis_shapes)} devices, " - f"but found {num_devices}." - ) - if devices is not None: - if np.prod(axis_shapes) != num_devices: - raise ValueError( - f"Mesh shape {axis_shapes} requires {np.prod(axis_shapes)} devices, " - f"but was assigned {num_devices}." - ) - return jax.sharding.Mesh( - np.array(list(devices)).reshape(axis_shapes), - axis_names, - axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), + "The 'mesh' configuration must be a dictionary-like object, got" + f" {type(mesh_config)}." 
) - return jax.make_mesh( - axis_shapes, - axis_names, - axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), + return mesh_lib.normalize_allocation_policy( + mesh_config.get("allocation_policy") ) def obtain_training_config_dict(self, key): diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py index 6c651c5e7..794491aea 100644 --- a/tunix/cli/grpo_main.py +++ b/tunix/cli/grpo_main.py @@ -52,6 +52,7 @@ from tunix.perf.experimental import export as perf_export_v2 from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout +from tunix.utils import mesh as mesh_lib _PATHWAYS_BNS = flags.DEFINE_string( @@ -189,7 +190,25 @@ def resolve_owner( role_to_owner[role] = resolve_owner(role, set()) return role_to_owner - def _create_role_to_mesh(self): + def create_role_to_mesh(self): + """Builds the role-to-mesh mapping for GRPO execution. + + Any role with an explicit ``*.mesh`` config gets a dedicated device slice. + Roles without a mesh share the actor mesh by default, or can point at + another role via ``same_mesh_as``. + + All mesh owners participating in the same allocation pass must agree on + one ``mesh.allocation_policy`` value. That policy is then passed to the + mesh allocator so users can choose between compact packing and + performance-oriented cubical packing from config. + + Returns: + A mapping from logical GRPO role to the concrete JAX mesh it should use. + + Raises: + ValueError: If mesh ownership resolution is invalid or if mesh owners + request conflicting allocation policies. 
+ """ devices = list(jax.devices()) role_to_owner = self._resolve_mesh_owners() owner_order = [] @@ -200,50 +219,36 @@ def _create_role_to_mesh(self): if owner not in owner_order: owner_order.append(owner) - owner_to_mesh = {} - owner_to_device_slice = {} - device_offset = 0 + mesh_requirements = [] + allocation_policy = None for owner in owner_order: model_key = self._ROLE_TO_MODEL_KEY[owner] - axis_shapes, _ = self._parse_mesh_config(model_key) - required_devices = int(np.prod(axis_shapes)) - next_offset = device_offset + required_devices - if next_offset > len(devices): + axis_shapes, _ = self.parse_mesh_config(model_key) + owner_policy = self._parse_mesh_allocation_policy(model_key) + if allocation_policy is None: + allocation_policy = owner_policy + elif owner_policy != allocation_policy: raise ValueError( - f"Mesh allocation requires {next_offset} devices after allocating" - f" {model_key}, but only {len(devices)} are available." + "All owned meshes must use the same mesh.allocation_policy, got " + f"{allocation_policy!r} and {owner_policy!r}." 
) - assigned_devices = devices[device_offset:next_offset] - owner_to_device_slice[owner] = assigned_devices - owner_to_mesh[owner] = self.create_mesh( - model_key, devices=assigned_devices - ) - device_offset = next_offset + mesh_requirements.append((model_key, int(np.prod(axis_shapes)))) - if device_offset < len(devices): - logging.warning( - "Mesh allocation used %d of %d devices; %d devices remain unused.", - device_offset, - len(devices), - len(devices) - device_offset, - ) - logging.info( - "Mesh device allocation: %s", - { - self._ROLE_TO_MODEL_KEY[owner]: len(owner_to_device_slice[owner]) - for owner in owner_order - }, + allocated_devices = mesh_lib.allocate_named_mesh_device_slices( + mesh_requirements, + devices=devices, + allocation_policy=allocation_policy or mesh_lib.normalize_allocation_policy(None), ) - return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} - - def create_role_to_mesh(self): - """Build role→mesh mapping. - Any role with an explicit ``*.mesh`` config gets a dedicated device slice. - Roles without a mesh share the actor mesh by default, or can point at - another role via ``same_mesh_as``. 
- """ - return self._create_role_to_mesh() + owner_to_mesh = {} + for owner in owner_order: + model_key = self._ROLE_TO_MODEL_KEY[owner] + axis_shapes, axis_names = self.parse_mesh_config(model_key) + assigned_devices = allocated_devices[model_key] + owner_to_mesh[owner] = mesh_lib.create_mesh( + axis_shapes, axis_names, devices=assigned_devices + ) + return {role: owner_to_mesh[owner] for role, owner in role_to_owner.items()} # ------------------------------------------------------------------ # Rollout config diff --git a/tunix/cli/peft_main.py b/tunix/cli/peft_main.py index e3724a7a4..652db3cb4 100644 --- a/tunix/cli/peft_main.py +++ b/tunix/cli/peft_main.py @@ -25,6 +25,7 @@ from tunix.examples.data import translation_dataset as data_lib from tunix.sft import peft_trainer from tunix.sft import utils +from tunix.utils import mesh as mesh_lib _PATHWAYS_BNS = flags.DEFINE_string( "pathways_bns", None, "BNS address of the Pathways server." @@ -36,7 +37,8 @@ class PeftPipeline(config.HyperParameters): def run_peft_trainer(self): """Run the PEFT trainer.""" - mesh: jax.sharding.Mesh = self.create_mesh('model_config') + axis_shapes, axis_names = self.parse_mesh_config('model_config') + mesh: jax.sharding.Mesh = mesh_lib.create_mesh(axis_shapes, axis_names) model: nnx.Module | None = None tokenizer: Any | None = None my_gen_model_input_fn: ( diff --git a/tunix/utils/mesh.py b/tunix/utils/mesh.py new file mode 100644 index 000000000..f511db29e --- /dev/null +++ b/tunix/utils/mesh.py @@ -0,0 +1,1544 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared mesh device allocation helpers. + +Typical usage: + + allocations = allocate_named_mesh_device_slices([ + ("actor", 8), + ("rollout", 4), + ], allocation_policy="COMPACT") + +The keys are arbitrary mesh names chosen by the caller. The integer is the +number of devices that mesh should receive. + +Supported allocation policies: + +* ``COMPACT``: pack meshes into the smallest fitting remaining coord region. +* ``PERFORMANCE``: prefer more cubical supported extracted shapes. +""" + +import collections +import dataclasses +from typing import Any, Sequence + +from absl import logging +import jax +import numpy as np +from tunix.utils import topology + +MeshRequirement = tuple[str, int] +_COMPACT_ALLOCATION_POLICY = "COMPACT" +_PERFORMANCE_ALLOCATION_POLICY = "PERFORMANCE" +_SUPPORTED_ALLOCATION_POLICIES = { + _COMPACT_ALLOCATION_POLICY, + _PERFORMANCE_ALLOCATION_POLICY, +} + + +def create_mesh( + axis_shapes: tuple[int, ...], + axis_names: tuple[str, ...], + devices: Sequence[Any] | None = None, +): + """Builds a JAX mesh from parsed axis metadata. + + Args: + axis_shapes: Mesh dimension sizes such as ``(2, 4)``. + axis_names: Mesh axis names such as ``("data", "model")``. + devices: Optional explicit device assignment. When omitted, Tunix uses the + default JAX device set. + + Returns: + A ``jax.sharding.Mesh`` with the requested logical shape. + + Raises: + ValueError: If the axis metadata is inconsistent or the requested mesh + shape does not match the available device count. 
+ """ + if len(axis_shapes) != len(axis_names): + raise ValueError( + f"mesh.shape {axis_shapes} and mesh.axis_names {axis_names} " + "must have the same length." + ) + + num_devices = len(devices) if devices is not None else jax.device_count() + required_devices = int(np.prod(axis_shapes)) + if required_devices > num_devices: + raise ValueError( + f"Mesh shape {axis_shapes} requires {required_devices} devices, " + f"but found {num_devices}." + ) + if devices is not None: + if required_devices != num_devices: + raise ValueError( + f"Mesh shape {axis_shapes} requires {required_devices} devices, " + f"but was assigned {num_devices}." + ) + return jax.sharding.Mesh( + np.array(list(devices)).reshape(axis_shapes), + axis_names, + axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), + ) + return jax.make_mesh( + axis_shapes, + axis_names, + axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names), + ) + + +@dataclasses.dataclass(frozen=True) +class CoordTopology: + """Normalized coord metadata for a device pool. + + Attributes: + coord_to_device: Mapping from physical coords to device objects. + all_coords: Normalized coord tuples for all devices. + num_dims: Number of coord dimensions. + min_coords: Minimum coord on each axis for the device pool. + max_shape: Bounding-box shape of the device pool. + """ + + coord_to_device: dict[tuple[int, ...], Any] + all_coords: tuple[tuple[int, ...], ...] + num_dims: int + min_coords: tuple[int, ...] + max_shape: tuple[int, ...] + chip_coord_to_coords: dict[tuple[int, ...], tuple[tuple[int, ...], ...]] + has_core_on_chip_dimension: bool + + +@dataclasses.dataclass(frozen=True) +class CoordRegion: + """Axis-aligned coord region tracked across sequential allocations.""" + + start: tuple[int, ...] + shape: tuple[int, ...] + + +@dataclasses.dataclass(frozen=True) +class DeviceAllocationState: + """Tracks the remaining device pool across sequential mesh allocations. 
+ + This state object exists so `allocate_devices()` can be the lowest-level + public API while still supporting multi-mesh allocation. Callers that only + need one mesh can pass `devices=` directly to `allocate_devices()`. Callers + that need multiple meshes can create state once and repeatedly allocate from + it, which is exactly what `allocate_named_mesh_device_slices()` does. + + Attributes: + remaining_devices: Flat view of devices that have not yet been assigned. + full_devices_per_host: Original per-host capacity derived from host groups. + host_bound_shape: Per-host physical topology shape, such as `(2, 2, 1)`. + total_device_count: Size of the original device pool. + remaining_coord_regions_by_slice: Remaining coord regions keyed by slice id. + When slice metadata is absent, the key is `None`. + full_devices_per_slice: Original per-slice capacity keyed by slice id. This + lets later allocations tell whether a remaining slice is still whole or + has already been partially consumed. + used_device_count: Number of devices already assigned. + """ + + remaining_devices: tuple[Any, ...] + full_devices_per_host: int + host_bound_shape: tuple[int, ...] | None + total_device_count: int + allocation_policy: str = _COMPACT_ALLOCATION_POLICY + remaining_coord_regions_by_slice: dict[Any, tuple[CoordRegion, ...]] | None = None + full_devices_per_slice: dict[Any, int] | None = None + used_device_count: int = 0 + + +def _normalize_allocation_policy(allocation_policy: str | None) -> str: + """Validates and normalizes the requested allocation policy.""" + normalized_policy = ( + _COMPACT_ALLOCATION_POLICY + if allocation_policy is None + else allocation_policy.upper() + ) + if normalized_policy not in _SUPPORTED_ALLOCATION_POLICIES: + raise ValueError( + "allocation_policy must be one of " + f"{sorted(_SUPPORTED_ALLOCATION_POLICIES)}, got {allocation_policy!r}." 
+ ) + return normalized_policy + + +def normalize_allocation_policy(allocation_policy: str | None) -> str: + """Validates and normalizes a user-facing mesh allocation policy. + + Args: + allocation_policy: Policy name provided by a caller or config. ``None`` + selects the default policy. + + Returns: + The normalized policy string. + + Raises: + ValueError: If the policy is unsupported. + """ + return _normalize_allocation_policy(allocation_policy) + + +def device_attr(device: Any, attr_name: str) -> Any: + """Returns a raw device attribute, calling it first if JAX exposes it lazily. + + Args: + device: A JAX device or test double. + attr_name: Attribute name such as "coords" or "process_index". + + Returns: + The attribute value, or None if the attribute does not exist. + """ + return topology._device_attr(device, attr_name) + + +def device_host_key(device: Any) -> tuple[Any, ...] | None: + """Returns a stable host grouping key for topology-aware allocation. + + Args: + device: A JAX device or test double. + + Returns: + A tuple of (slice_id, task_id) when that metadata is available, otherwise + None. + """ + return topology._device_host_key(device) + + +def device_slice_id(device: Any) -> Any: + """Returns the slice identifier when the runtime exposes one. + + This is intentionally narrower than `device_host_key()`: it captures only the + slice boundary, not the host/task within that slice. Slice-aware allocation + uses this to prefer satisfying a mesh from one slice before spilling into the + next slice. + """ + return device_attr(device, "slice_index") + + +def device_mesh_coords(device: Any) -> tuple[int, ...] | None: + """Returns physical mesh coordinates for topology-aware allocation. + + Args: + device: A JAX device or test double. + + Returns: + A tuple like (x, y, z) or (x, y, z, core) when the runtime exposes device + coordinates, otherwise None. 
+ """ + coords = device_attr(device, "coords") + if coords is None: + return None + + coords = tuple(coords) + if not coords: + return None + + normalized_coords = tuple(int(coord) for coord in coords) + core_on_chip = device_attr(device, "core_on_chip") + if core_on_chip is None: + return normalized_coords + return normalized_coords + (int(core_on_chip),) + + +def _pool_allows_2d_chip_coords(devices: Sequence[Any]) -> bool: + """Returns whether raw 2D chip coords are valid for this device pool.""" + return _device_family(devices) in {"v5e", "v6e"} + + +def _canonical_device_chip_coords( + device: Any, + *, + allow_2d: bool, +) -> tuple[int, int, int] | None: + """Returns canonical 3D chip coords for one device when available.""" + parsed = topology._device_coords(device) + if parsed is None: + return None + if len(parsed) == 2 and allow_2d: + return topology._canonicalize_chip_shape_to_3d(parsed) + if len(parsed) == 3: + return parsed + return None + + +def infer_core_on_chip_count(devices: Sequence[Any]) -> int | None: + """Returns the per-chip core count when the runtime exposes it consistently.""" + chip_to_cores = collections.defaultdict(set) + saw_any_core = False + + for device in devices: + coords = device_attr(device, "coords") + core_on_chip = device_attr(device, "core_on_chip") + if coords is None: + return None + if core_on_chip is None: + continue + saw_any_core = True + chip_to_cores[tuple(int(coord) for coord in coords)].add(int(core_on_chip)) + + if not saw_any_core: + return None + + core_counts = {len(core_ids) for core_ids in chip_to_cores.values()} + if len(core_counts) != 1: + return None + return next(iter(core_counts)) + + +def summarize_devices_for_logging(devices: Sequence[Any]) -> list[dict[str, Any]]: + """Builds compact log-friendly summaries for a device list. + + Args: + devices: Devices to summarize. + + Returns: + A list of dictionaries containing device id, coords, and inferred host key. 
+ """ + summaries = [] + for device in devices: + summaries.append({ + "id": device_attr(device, "id"), + "coords": device_mesh_coords(device), + "host": device_host_key(device), + }) + return summaries + + +def summarize_devices_for_debug_logging( + devices: Sequence[Any], + limit: int = 16, +) -> list[dict[str, Any]]: + """Builds richer device summaries for topology debugging. + + Args: + devices: Devices to summarize. + limit: Maximum number of devices to include. + + Returns: + A list of dictionaries with raw device topology metadata. + """ + summaries = [] + for device in devices[:limit]: + summaries.append({ + "id": device_attr(device, "id"), + "coords": device_attr(device, "coords"), + "core_on_chip": device_attr(device, "core_on_chip"), + "process_index": device_attr(device, "process_index"), + "task_id": device_attr(device, "task_id"), + "slice_index": device_attr(device, "slice_index"), + "slice": device_attr(device, "slice"), + "host": device_host_key(device), + }) + return summaries + + +def summarize_host_groups_for_logging(devices: Sequence[Any]) -> dict[tuple[Any, ...], int]: + """Summarizes device counts per derived host key for debug logging.""" + host_counts = collections.Counter() + for device in devices: + host_key = device_host_key(device) + host_counts[host_key] += 1 + return dict(sorted(host_counts.items(), key=lambda item: str(item[0]))) + + +def summarize_coord_regions_for_logging( + coord_regions_by_slice: dict[Any, tuple[CoordRegion, ...]] | None, +) -> dict[Any, list[dict[str, tuple[int, ...]]]] | None: + """Builds a compact log-friendly summary of remaining coord regions. + + Args: + coord_regions_by_slice: Remaining coord-region partition keyed by slice id. + + Returns: + A JSON-like summary suitable for logging, or ``None`` when no coord-region + state is available. 
+ """ + if coord_regions_by_slice is None: + return None + return { + slice_id: [ + {"start": region.start, "shape": region.shape} + for region in regions + ] + for slice_id, regions in sorted( + coord_regions_by_slice.items(), + key=lambda item: str(item[0]), + ) + } + + +def _optional_int_sort_key(value: Any) -> tuple[int, int]: + """Builds a deterministic sort key component for optional integer metadata.""" + if value is None: + return (1, 0) + return (0, int(value)) + + +def _host_sort_key(host_key: tuple[Any, ...] | None) -> tuple[tuple[int, Any], ...]: + """Builds a stable sort key for derived host identifiers. + + Args: + host_key: Host grouping key such as `(slice_id, task_id)`, or None when + host metadata is unavailable. + + Returns: + A tuple suitable for deterministic sorting. Missing host metadata sorts + after concrete integer values. + """ + if host_key is None: + return (_optional_int_sort_key(None),) + return tuple(_optional_int_sort_key(value) for value in host_key) + + +def allocation_device_sort_key(device: Any) -> tuple[Any, ...]: + """Sort key for deterministic allocation order: slice, host, coords, id.""" + host_key = device_host_key(device) + host_id = host_key[1] if host_key is not None and len(host_key) > 1 else None + coords = device_mesh_coords(device) + return ( + _optional_int_sort_key(device_slice_id(device)), + _optional_int_sort_key(host_id), + coords or (), + _optional_int_sort_key(device_attr(device, "id")), + ) + + +def group_devices_by_slice(devices: Sequence[Any]) -> list[list[Any]] | None: + """Groups devices by slice while preserving first-seen slice order. + + Devices without slice metadata are treated as belonging to one shared slice. + The order of groups matches the first appearance of each slice in `devices`, + which lets the allocator prefer earlier slices before spilling into later + ones. 
+ """ + slice_to_devices = {} + for device in devices: + slice_id = device_slice_id(device) + slice_to_devices.setdefault(slice_id, []).append(device) + return list(slice_to_devices.values()) + + +def group_devices_by_host(devices: Sequence[Any]) -> list[list[Any]] | None: + """Groups devices by host/task when that metadata is available. + + Args: + devices: Candidate devices to partition. + + Returns: + A list of equal-sized per-host device lists, or None if host metadata is + missing or inconsistent. + + Notes: + Host buckets are returned in deterministic host-key order, and devices + inside each bucket are sorted with `allocation_device_sort_key()` so host + grouping remains reproducible for callers and validations that still use + this helper. + """ + host_to_devices = {} + for device in devices: + host_key = device_host_key(device) + if host_key is None: + return None + host_to_devices.setdefault(host_key, []).append(device) + + host_sizes = {len(host_devices) for host_devices in host_to_devices.values()} + if len(host_sizes) != 1: + logging.warning( + "Falling back to flat device allocation because host sizes differ: %s", + sorted(host_sizes), + ) + return None + ordered_host_keys = sorted(host_to_devices, key=_host_sort_key) + return [ + sorted(host_to_devices[host_key], key=allocation_device_sort_key) + for host_key in ordered_host_keys + ] + + +def get_coord_topology(devices: Sequence[Any]) -> CoordTopology | None: + """Builds normalized coord metadata for a device pool. + + Args: + devices: Candidate devices to inspect. + + Returns: + A CoordTopology describing the device coords and overall bounding box, or + None when the devices do not expose a consistent coord layout. 
+ """ + if not devices: + return None + + coord_to_device = {} + all_coords = [] + has_core_on_chip_dimension = False + for device in devices: + if device_attr(device, "core_on_chip") is not None: + has_core_on_chip_dimension = True + coords = device_mesh_coords(device) + if coords is None: + logging.info( + "Coord topology unavailable because device lacks coords: %s", + summarize_devices_for_debug_logging([device]), + ) + return None + if all_coords and len(coords) != len(all_coords[0]): + logging.info( + "Coord topology unavailable because coord rank differs: existing_rank=%d device=%s", + len(all_coords[0]), + summarize_devices_for_debug_logging([device]), + ) + return None + if coords in coord_to_device: + logging.info( + "Coord topology unavailable because multiple devices share coords %s: %s", + coords, + summarize_devices_for_debug_logging([coord_to_device[coords], device]), + ) + return None + coord_to_device[coords] = device + all_coords.append(coords) + + num_dims = len(all_coords[0]) + min_coords = tuple( + min(coords[dim] for coords in all_coords) + for dim in range(num_dims) + ) + chip_coord_to_coords = collections.defaultdict(list) + for coords in all_coords: + chip_coord = coords[:-1] if has_core_on_chip_dimension else coords + chip_coord_to_coords[chip_coord].append(coords) + max_shape = tuple( + max(coords[dim] for coords in all_coords) + - min(coords[dim] for coords in all_coords) + + 1 + for dim in range(num_dims) + ) + return CoordTopology( + coord_to_device=coord_to_device, + all_coords=tuple(all_coords), + num_dims=num_dims, + min_coords=min_coords, + max_shape=max_shape, + chip_coord_to_coords={ + chip_coord: tuple(sorted(group_coords)) + for chip_coord, group_coords in chip_coord_to_coords.items() + }, + has_core_on_chip_dimension=has_core_on_chip_dimension, + ) + + +def candidate_uses_whole_chips( + coord_topology: CoordTopology, + candidate_coords: Sequence[tuple[int, ...]], +) -> bool: + """Returns whether a candidate includes all logical 
devices for each chip. + + When multiple logical devices share the same physical chip coordinates, a + valid Pathways subslice must include either all of them or none of them. + This rejects candidates that split `core_on_chip` siblings across meshes. + """ + if not coord_topology.has_core_on_chip_dimension: + return True + + selected_coords = set(candidate_coords) + selected_chip_coords = {coords[:-1] for coords in selected_coords} + for chip_coord in selected_chip_coords: + chip_group = coord_topology.chip_coord_to_coords.get(chip_coord, ()) + if any(coords not in selected_coords for coords in chip_group): + return False + return True + + +def known_host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None: + """Returns known host bounds from static topology metadata when available. + + Args: + devices: Devices from a single TPU slice. + + Returns: + A known per-host physical bound in canonical 3D chip form, optionally with + a trailing core-count dimension. + """ + bounds = topology.infer_chips_per_host_bounds(devices) + if bounds is None: + return None + + allow_2d = _pool_allows_2d_chip_coords(devices) + chip_coords = ( + _canonical_device_chip_coords(devices[0], allow_2d=allow_2d) + if devices + else None + ) + if chip_coords is None: + return None + + if device_attr(devices[0], "core_on_chip") is not None: + core_count = infer_core_on_chip_count(devices) + if core_count is None: + return None + return bounds + (core_count,) + + return bounds + + +def resolve_per_host_mesh_shape(devices: Sequence[Any]) -> tuple[int, ...] | None: + """Returns the known per-host shape and validates each host against it. + + Args: + devices: Devices spanning one or more hosts. + + Returns: + The known per-host shape when static topology metadata is available. + + Raises: + ValueError: If any discovered host bucket does not match the known per-host + shape for the device family. 
+ """ + static_shape = known_host_mesh_shape(devices) + if static_shape is None: + return None + + host_groups = group_devices_by_host(devices) + if host_groups is None: + return static_shape + + for host_devices in host_groups: + if not _satisfies_host_bound_shape( + host_devices, + static_shape, + len(host_devices), + ): + raise ValueError( + "Observed host devices do not match known host bounds " + f"{static_shape}." + ) + return static_shape + + +def _divisors(value: int) -> list[int]: + """Returns the positive divisors of `value` in ascending order.""" + divisors = set() + for candidate in range(1, int(np.sqrt(value)) + 1): + if value % candidate == 0: + divisors.add(candidate) + divisors.add(value // candidate) + return sorted(divisors) + + +def _enumerate_box_shapes( + required_devices: int, + max_shape: tuple[int, ...], +) -> list[tuple[int, ...]]: + """Enumerates box shapes whose volume matches the requested device count.""" + shapes = [] + num_dims = len(max_shape) + + def build(dim_index: int, remaining: int, prefix: tuple[int, ...]): + if dim_index == num_dims - 1: + if remaining <= max_shape[dim_index]: + shapes.append(prefix + (remaining,)) + return + + for size in _divisors(remaining): + if size > max_shape[dim_index]: + continue + build(dim_index + 1, remaining // size, prefix + (size,)) + + build(0, required_devices, ()) + return shapes + + +def _device_family(devices: Sequence[Any]) -> str | None: + """Returns the normalized accelerator family for a device pool.""" + if not devices: + return None + device_kind = device_attr(devices[0], "device_kind") + if not isinstance(device_kind, str): + return None + return topology._resolve_family(device_kind) + + +def _supported_coord_box_shapes( + devices: Sequence[Any], + coord_topology: CoordTopology, + required_devices: int, + available_coord_shape: Sequence[int] | None = None, +) -> list[tuple[int, ...]] | None: + """Returns topology-valid physical box shapes for the current device pool. 
+ + When the accelerator family is known, this narrows box search to the exact + TPU topology shapes that can legally realize `required_devices` on the + current cluster. For unknown families, callers should fall back to generic + contiguous-box enumeration. + """ + family = _device_family(devices) + if family is None: + return None + + core_count = infer_core_on_chip_count(devices) or 1 + if required_devices % core_count != 0: + return [] + + chip_rank = coord_topology.num_dims - (1 if coord_topology.has_core_on_chip_dimension else 0) + if chip_rank <= 0: + return [] + available_shape = tuple(available_coord_shape or coord_topology.max_shape) + available_chip_shape = available_shape[:chip_rank] + required_chips = required_devices // core_count + + candidate_chip_shapes = [ + chip_shape + for chip_shape in topology.best_topology_shapes_for_chip_count( + family, + required_chips, + chip_rank=chip_rank, + available_chip_shape=available_chip_shape, + ) + ] + + if not candidate_chip_shapes: + return [] + if coord_topology.has_core_on_chip_dimension: + return [chip_shape + (core_count,) for chip_shape in candidate_chip_shapes] + return candidate_chip_shapes + + +def _full_coord_region(coord_topology: CoordTopology) -> CoordRegion: + """Returns the one bounding coord region for a concrete device pool.""" + return CoordRegion( + start=coord_topology.min_coords, + shape=coord_topology.max_shape, + ) + + +def _create_coord_regions(devices: Sequence[Any]) -> tuple[CoordRegion, ...] 
| None: + """Builds the initial coord-region partition for one device pool.""" + coord_topology = get_coord_topology(devices) + if coord_topology is None: + return None + return (_full_coord_region(coord_topology),) + + +def _create_coord_regions_by_slice( + devices: Sequence[Any], +) -> dict[Any, tuple[CoordRegion, ...]] | None: + """Builds initial coord-region partitions keyed by slice id when present.""" + slice_groups = group_devices_by_slice(devices) + if not slice_groups: + coord_regions = _create_coord_regions(devices) + if coord_regions is None: + return None + return {None: coord_regions} + + coord_regions_by_slice = {} + for slice_devices in slice_groups: + if not slice_devices: + continue + coord_regions = _create_coord_regions(slice_devices) + if coord_regions is None: + continue + coord_regions_by_slice[device_slice_id(slice_devices[0])] = coord_regions + return coord_regions_by_slice or None + + +def _split_coord_region( + region: CoordRegion, + allocated_start: tuple[int, ...], + allocated_shape: tuple[int, ...], +) -> tuple[CoordRegion, ...]: + """Splits a consumed region into remaining z/y/x-style guillotine regions. + + Args: + region: Region being carved up. + allocated_start: Start coordinate of the allocated box. + allocated_shape: Shape of the allocated box. + + Returns: + The remaining regions after subtracting the allocated box. + + Raises: + ValueError: If the allocated box rank does not match the region or falls + outside the region bounds. + """ + if len(region.shape) != len(allocated_shape) or len(region.shape) != len(allocated_start): + raise ValueError( + "Coord region rank does not match allocated box: " + f"region={region.shape} start={allocated_start} allocated={allocated_shape}." 
+ ) + region_end = tuple( + region.start[dim] + region.shape[dim] + for dim in range(len(region.shape)) + ) + allocated_end = tuple( + allocated_start[dim] + allocated_shape[dim] + for dim in range(len(region.shape)) + ) + if any( + allocated_start[dim] < region.start[dim] + or allocated_end[dim] > region_end[dim] + for dim in range(len(region.shape)) + ): + raise ValueError( + "Allocated box does not fit within coord region: " + f"region_start={region.start} region_shape={region.shape} " + f"allocated_start={allocated_start} allocated_shape={allocated_shape}." + ) + + remaining_regions = [] + num_dims = len(region.shape) + for dim in range(num_dims - 1, -1, -1): + before_dim = allocated_start[dim] - region.start[dim] + after_dim = region_end[dim] - allocated_end[dim] + + if before_dim > 0: + start = list(region.start) + shape = list(region.shape) + for earlier_dim in range(dim): + start[earlier_dim] = allocated_start[earlier_dim] + shape[earlier_dim] = allocated_shape[earlier_dim] + shape[dim] = before_dim + remaining_regions.append(CoordRegion(tuple(start), tuple(shape))) + + if after_dim > 0: + start = list(region.start) + shape = list(region.shape) + for earlier_dim in range(dim): + start[earlier_dim] = allocated_start[earlier_dim] + shape[earlier_dim] = allocated_shape[earlier_dim] + start[dim] = allocated_end[dim] + shape[dim] = after_dim + remaining_regions.append(CoordRegion(tuple(start), tuple(shape))) + return tuple(remaining_regions) + + +def _region_contains_box( + region: CoordRegion, + start: tuple[int, ...], + shape: tuple[int, ...], +) -> bool: + """Returns whether a candidate box lies fully within a coord region.""" + return all( + region.start[dim] <= start[dim] + and start[dim] + shape[dim] <= region.start[dim] + region.shape[dim] + for dim in range(len(region.shape)) + ) + + +def _build_candidate_coords( + coord_topology: CoordTopology, + start: tuple[int, ...], + shape: tuple[int, ...], +) -> list[tuple[int, ...]] | None: + """Builds the coord 
list for one candidate box if the full box exists."""
+  candidate_coords = []
+  for offset in np.ndindex(shape):
+    candidate_coord = tuple(
+        start[dim] + offset[dim] for dim in range(coord_topology.num_dims)
+    )
+    if candidate_coord not in coord_topology.coord_to_device:
+      return None
+    candidate_coords.append(candidate_coord)
+  return candidate_coords
+
+
+def _coord_box_score(
+    start: tuple[int, ...],
+    shape: tuple[int, ...],
+) -> tuple[Any, ...]:
+  """Builds a lexicographic sort key for candidate coord boxes.
+
+  The returned tuple is ordered so Python tuple comparison implements the
+  desired ranking policy directly:
+
+  1. Prefer boxes with a smaller maximum dimension.
+  2. Prefer more compact overall shapes, then larger per-axis sizes.
+  3. Prefer earlier start coordinates as a stable tiebreaker.
+
+  Args:
+    start: Candidate box origin.
+    shape: Candidate box shape.
+
+  Returns:
+    A tuple sort key suitable for lexicographic comparison.
+  """
+  return (
+      max(shape),
+      tuple(sorted(shape, reverse=True)),
+      tuple(-dim for dim in shape),
+      start,
+  )
+
+
+def _order_candidate_regions(
+    candidate_regions: Sequence[tuple[CoordRegion, Sequence[tuple[int, ...]]]],
+    allocation_policy: str,
+) -> list[tuple[CoordRegion, Sequence[tuple[int, ...]]]]:
+  """Orders candidate regions according to the requested allocation policy.
+
+  Args:
+    candidate_regions: Candidate remaining regions paired with the supported
+      shapes each region can realize for the current request.
+    allocation_policy: ``COMPACT`` or ``PERFORMANCE``.
+
+  Returns:
+    Candidate regions ordered according to the requested policy.
+ """ + if allocation_policy == _COMPACT_ALLOCATION_POLICY: + return sorted( + candidate_regions, + key=lambda item: ( + _coord_box_score(item[0].start, item[0].shape), + _coord_box_score(item[0].start, item[1][0]), + ), + ) + return sorted( + candidate_regions, + key=lambda item: ( + _coord_box_score(item[0].start, item[1][0]), + _coord_box_score(item[0].start, item[0].shape), + ), + ) + + +def _find_best_candidate_box( + coord_topology: CoordTopology, + candidate_shapes: Sequence[tuple[int, ...]], + *, + region: CoordRegion | None = None, +) -> tuple[tuple[int, ...], tuple[int, ...], list[tuple[int, ...]]] | None: + """Finds the best valid candidate box, optionally constrained to a region. + + This is the shared internal scan used by both region-aware allocation and + whole-topology fallback allocation. Unlike `find_best_candidate_coords()`, + it returns the full winning box metadata so callers can both assign devices + and update remaining-region state. + + Args: + coord_topology: Normalized coord metadata for the candidate device pool. + candidate_shapes: Exact box shapes to consider. + region: Optional remaining coord region that candidate boxes must fit + inside. When omitted, the scan considers the whole topology. + + Returns: + ``(start, shape, coords)`` for the best-ranked valid box, or ``None`` when + no candidate fits. 
+ """ + best_candidate = None + best_score = None + + for shape in candidate_shapes: + for start in sorted(coord_topology.coord_to_device): + if region is not None and not _region_contains_box(region, start, shape): + continue + candidate_coords = _build_candidate_coords(coord_topology, start, shape) + if candidate_coords is None: + continue + if not candidate_uses_whole_chips(coord_topology, candidate_coords): + continue + score = _coord_box_score(start, shape) + if best_score is None or score < best_score: + best_score = score + best_candidate = (start, shape, candidate_coords) + + return best_candidate + + +def find_best_candidate_coords( + coord_topology: CoordTopology, + required_devices: int, + candidate_shapes: Sequence[tuple[int, ...]] | None = None, +) -> list[tuple[int, ...]] | None: + """Returns only the coord list for the best candidate box. + + Args: + coord_topology: Normalized coord metadata for the candidate device pool. + required_devices: Number of devices needed for one mesh. + candidate_shapes: Optional exact physical shapes to scan instead of + enumerating every factorization of `required_devices`. + + Returns: + The coord list for the best-ranked candidate box, or ``None`` when no + valid box exists. + + Notes: + This is a thin convenience wrapper over `_find_best_candidate_box()` for + callers that only need the selected coordinates. It intentionally discards + the winning box start and shape, which the region-aware allocator still + needs in order to split remaining regions. 
+  """
+  shapes = candidate_shapes or _enumerate_box_shapes(
+      required_devices,
+      coord_topology.max_shape,
+  )
+  best_candidate = _find_best_candidate_box(
+      coord_topology,
+      shapes,
+  )
+  if best_candidate is None:
+    return None
+  return list(best_candidate[2])
+
+
+def _allocate_devices_by_coords(
+    devices: Sequence[Any],
+    required_devices: int,
+    coord_regions: Sequence[CoordRegion] | None = None,
+    allocation_policy: str = _COMPACT_ALLOCATION_POLICY,
+) -> tuple[list[Any] | None, tuple[CoordRegion, ...] | None]:
+  """Allocates a contiguous physical box of devices when coords exist.
+
+  Args:
+    devices: Candidate devices to allocate from.
+    required_devices: Number of devices needed for one mesh.
+    coord_regions: Optional remaining-region partition to respect during
+      incremental allocations.
+    allocation_policy: ``COMPACT`` prefers the smallest fitting remaining
+      region; ``PERFORMANCE`` prefers the most cubical supported extracted
+      shape.
+
+  Returns:
+    A tuple of `(assigned_devices, next_coord_regions)`. `assigned_devices` is
+    the best contiguous physical box, or None if the devices do not expose
+    usable coordinates. `next_coord_regions` preserves a guillotine-style
+    partition when allocation consumed one tracked region from its origin.
+
+  Notes:
+    This helper runs in four stages:
+
+    1. Build normalized coord metadata with `get_coord_topology()`.
+    2. Derive preferred physical shapes for the device family.
+    3. First try the tracked remaining coord regions, which preserves the
+       guillotine-style region partition used across incremental allocations.
+    4. If no tracked region can realize a valid box, fall back to a
+       whole-topology scan with `find_best_candidate_coords()`. This exists
+       because the tracked region partition is conservative bookkeeping: it is
+       useful for incremental allocation, but it does not represent every
+       contiguous box that may still exist in the remaining device pool.
+ """ + coord_topology = get_coord_topology(devices) + if coord_topology is None: + return None, None + + allocation_policy = _normalize_allocation_policy(allocation_policy) + regions = tuple(coord_regions or (_full_coord_region(coord_topology),)) + candidate_regions = [] + for region in regions: + candidate_shapes = _supported_coord_box_shapes( + devices, + coord_topology, + required_devices, + available_coord_shape=region.shape, + ) + if not candidate_shapes: + continue + candidate_regions.append((region, candidate_shapes)) + + candidate_regions = _order_candidate_regions(candidate_regions, allocation_policy) + for region, candidate_shapes in candidate_regions: + best_region_candidate = _find_best_candidate_box( + coord_topology, + candidate_shapes, + region=region, + ) + if best_region_candidate is None: + continue + candidate_start, candidate_shape, candidate_coords = best_region_candidate + selected_coords = set(candidate_coords) + assigned_devices = [ + device + for device in devices + if device_mesh_coords(device) in selected_coords + ] + next_coord_regions = tuple( + remaining_region + for existing_region in regions + for remaining_region in ( + _split_coord_region(existing_region, candidate_start, candidate_shape) + if existing_region == region + else (existing_region,) + ) + ) + return assigned_devices, next_coord_regions + + candidate_shapes = _supported_coord_box_shapes( + devices, + coord_topology, + required_devices, + ) + best_candidate_coords = find_best_candidate_coords( + coord_topology, + required_devices, + candidate_shapes=candidate_shapes, + ) + if best_candidate_coords is None: + return None, tuple(regions) + + selected_coords = set(best_candidate_coords) + return [ + device + for device in devices + if device_mesh_coords(device) in selected_coords + ], _create_coord_regions([ + device + for device in devices + if device_mesh_coords(device) not in selected_coords + ]) + + +def _create_device_allocation_state( + devices: Sequence[Any] | None = 
None, + *, + allocation_policy: str = _COMPACT_ALLOCATION_POLICY, + log_summary: bool = True, +) -> DeviceAllocationState: + """Builds reusable allocator state for one or more mesh allocations. + + This is intentionally private because callers should not need to understand + the allocator internals to request one mesh. The public entry point is + `allocate_devices()`, which accepts either raw `devices` for one-shot use or + an existing `allocation_state` for incremental use. + + Args: + devices: Optional explicit device pool. When omitted, uses ``jax.devices()``. + allocation_policy: Allocation policy carried forward for later incremental + allocations. + log_summary: Whether to emit debug summaries for the initial pool. + + Returns: + A state object containing the remaining device pool plus cached host and + slice capacity metadata used by later allocation calls. + """ + allocation_policy = _normalize_allocation_policy(allocation_policy) + all_devices = tuple( + sorted( + jax.devices() if devices is None else devices, + key=allocation_device_sort_key, + ) + ) + if log_summary: + logging.info( + "Mesh allocator raw device sample: %s", + summarize_devices_for_debug_logging(all_devices), + ) + logging.info( + "Mesh allocator derived host groups: %s", + summarize_host_groups_for_logging(all_devices), + ) + host_groups = group_devices_by_host(all_devices) + slice_groups = group_devices_by_slice(all_devices) + remaining_coord_regions_by_slice = _create_coord_regions_by_slice(all_devices) + if log_summary: + logging.info( + "Mesh allocator derived coord regions: %s", + summarize_coord_regions_for_logging(remaining_coord_regions_by_slice), + ) + full_devices_per_host = ( + len(host_groups[0]) if host_groups else 0 + ) + full_devices_per_slice = None + if slice_groups is not None: + full_devices_per_slice = { + device_slice_id(slice_devices[0]): len(slice_devices) + for slice_devices in slice_groups + if slice_devices + } + host_bound_shape = 
resolve_per_host_mesh_shape(all_devices) + return DeviceAllocationState( + remaining_devices=all_devices, + full_devices_per_host=full_devices_per_host, + host_bound_shape=host_bound_shape, + total_device_count=len(all_devices), + allocation_policy=allocation_policy, + remaining_coord_regions_by_slice=remaining_coord_regions_by_slice, + full_devices_per_slice=full_devices_per_slice, + ) + + +def _allocate_devices_from_pool( + required_devices: int, + remaining_devices: list[Any], + mesh_name: str, + coord_regions: Sequence[CoordRegion] | None = None, + allocation_policy: str = _COMPACT_ALLOCATION_POLICY, +) -> tuple[list[Any], list[Any], tuple[CoordRegion, ...] | None]: + """Allocates one mesh from a concrete device pool without slice policy. + + This helper contains the pool-local allocation strategy used after any + slice-level decision has already been made. + + Coord-box allocation is mandatory here. If the remaining devices cannot form + a valid coord-based box for the request, allocation fails instead of falling + back to host buckets or flat prefix assignment. + + Args: + required_devices: Number of devices requested. + remaining_devices: Current flat device pool. + mesh_name: Name used for diagnostics. + coord_regions: Optional remaining-region partition for coord allocation. + allocation_policy: Coord-allocation policy to apply. + + Returns: + A tuple of assigned devices, the remaining flat pool, and the updated + coord-region partition. + """ + if required_devices > len(remaining_devices): + raise ValueError( + f"Mesh allocation requires {required_devices} devices for {mesh_name}, " + f"but only {len(remaining_devices)} remain available." 
+ ) + + assigned_devices, next_coord_regions = _allocate_devices_by_coords( + remaining_devices, + required_devices, + coord_regions, + allocation_policy, + ) + if assigned_devices is None: + raise ValueError( + f"Mesh allocation requires {required_devices} devices for {mesh_name}, " + "but coord-based allocation could not construct a valid box from the " + "remaining devices." + ) + + remaining_devices = _remove_devices_by_identity( + remaining_devices, + assigned_devices, + ) + return assigned_devices, remaining_devices, next_coord_regions + + +def allocate_devices( + required_devices: int, + devices: Sequence[Any] | None = None, + *, + mesh_name: str = "allocated_mesh", + allocation_state: DeviceAllocationState | None = None, + allocation_policy: str | None = None, + return_state: bool = False, +) -> list[Any] | tuple[list[Any], DeviceAllocationState]: + """Allocates devices for a single mesh request. + + This is the lowest-level public allocation API. It handles exactly one mesh + request and applies the allocator policy in priority order: + + 1. Reuse or create allocation state for the current remaining device pool. + 2. When one slice group remains, allocate directly from that slice. + 3. When multiple slice groups remain, first try satisfying the request + within one slice. + 4. If no single slice worked, a request may span slices only by consuming + whole remaining slices. + + There are two intended calling modes: + + 1. One-shot allocation: pass `devices=` and receive a single allocation. + 2. Incremental allocation: pass `allocation_state=` and, when + `return_state=True`, receive the updated remaining pool for the next call. + + `allocate_named_mesh_device_slices()` is implemented as a thin loop around + this function. + + Args: + required_devices: Number of devices to allocate for this mesh. + devices: Raw device pool for one-shot use. Mutually exclusive with + `allocation_state`. + mesh_name: Name used only for diagnostics and error messages. 
+ allocation_state: Existing state for incremental allocation. + allocation_policy: Optional allocation policy for one-shot use. When an + existing `allocation_state` is provided, any explicit policy must match + the policy stored in that state. + return_state: Whether to return the updated allocation state alongside the + assigned devices. + + Returns: + Either the assigned device list, or `(assigned_devices, next_state)` when + `return_state=True`. + + Raises: + ValueError: If both `devices` and `allocation_state` are provided, or if + the request cannot be satisfied from the remaining device pool. + """ + if devices is not None and allocation_state is not None: + raise ValueError( + "Pass either devices or allocation_state to allocate_devices, not both." + ) + + owns_state = allocation_state is None + if allocation_state is None: + state = _create_device_allocation_state( + devices, + allocation_policy=_normalize_allocation_policy(allocation_policy), + ) + else: + state = allocation_state + if allocation_policy is not None: + normalized_policy = _normalize_allocation_policy(allocation_policy) + if normalized_policy != state.allocation_policy: + raise ValueError( + "allocation_policy must match allocation_state.allocation_policy, " + f"got {normalized_policy!r} and {state.allocation_policy!r}." 
+ ) + remaining_devices = list(state.remaining_devices) + remaining_coord_regions_by_slice = ( + dict(state.remaining_coord_regions_by_slice) + if state.remaining_coord_regions_by_slice is not None + else None + ) + assigned_devices = None + allocation_error_prefix = ( + f"Mesh allocation requires {required_devices} devices for {mesh_name}, " + ) + + slice_groups = group_devices_by_slice(remaining_devices) + if assigned_devices is None and slice_groups: + for slice_devices in slice_groups: + if len(slice_devices) < required_devices: + continue + slice_state = _create_device_allocation_state( + slice_devices, + log_summary=False, + ) + slice_id = device_slice_id(slice_devices[0]) + slice_coord_regions = None + if remaining_coord_regions_by_slice is not None: + slice_coord_regions = remaining_coord_regions_by_slice.get(slice_id) + assigned_devices, _, next_slice_coord_regions = _allocate_devices_from_pool( + required_devices, + list(slice_state.remaining_devices), + mesh_name, + coord_regions=slice_coord_regions, + allocation_policy=state.allocation_policy, + ) + remaining_devices = _remove_devices_by_identity( + remaining_devices, + assigned_devices, + ) + if remaining_coord_regions_by_slice is not None: + remaining_coord_regions_by_slice[slice_id] = next_slice_coord_regions or () + break + + # If no single slice is large enough, a cross-slice mesh must consume + # whole slices in order. This avoids partial-slice allocation and keeps + # cross-slice policy simple. 
+ if ( + assigned_devices is None + and len(slice_groups) > 1 + and len(remaining_devices) >= required_devices + ): + assigned_devices = [] + assigned_slice_groups = [] + assigned_device_count = 0 + for slice_devices in slice_groups: + if assigned_device_count >= required_devices: + break + slice_id = device_slice_id(slice_devices[0]) + if ( + state.full_devices_per_slice is not None + and len(slice_devices) != state.full_devices_per_slice.get(slice_id) + ): + continue + assigned_slice_groups.append(slice_devices) + assigned_devices.extend(slice_devices) + assigned_device_count += len(slice_devices) + + if assigned_device_count == required_devices: + for slice_devices in assigned_slice_groups: + remaining_devices = _remove_devices_by_identity( + remaining_devices, + slice_devices, + ) + if remaining_coord_regions_by_slice is not None: + remaining_coord_regions_by_slice.pop(device_slice_id(slice_devices[0]), None) + else: + raise ValueError( + allocation_error_prefix + + "but cross-slice allocation only supports whole slices." + ) + + if assigned_devices is None: + raise ValueError( + allocation_error_prefix + + f"but only {len(remaining_devices)} remain available." 
+ ) + + next_state = dataclasses.replace( + state, + remaining_devices=tuple(remaining_devices), + remaining_coord_regions_by_slice=remaining_coord_regions_by_slice, + used_device_count=state.used_device_count + len(assigned_devices), + ) + logging.info( + "Allocated devices for %s: %s", + mesh_name, + summarize_devices_for_logging(assigned_devices), + ) + logging.info( + "Remaining coord regions after %s: %s", + mesh_name, + summarize_coord_regions_for_logging(remaining_coord_regions_by_slice), + ) + + if owns_state and not return_state: + unused_device_count = next_state.total_device_count - next_state.used_device_count + if unused_device_count > 0: + logging.warning( + "Mesh allocation used %d of %d devices; %d devices remain unused.", + next_state.used_device_count, + next_state.total_device_count, + unused_device_count, + ) + + if return_state: + return assigned_devices, next_state + return assigned_devices + + +def _remove_devices_by_identity( + devices: Sequence[Any], + assigned_devices: Sequence[Any], +) -> list[Any]: + """Removes assigned devices from a pool using object identity. + + Identity-based removal avoids ambiguity when test doubles or device objects + compare equal by value but still represent distinct runtime objects. + """ + assigned_device_ids = {id(device) for device in assigned_devices} + return [device for device in devices if id(device) not in assigned_device_ids] + + +def _satisfies_host_bound_shape( + host_devices: Sequence[Any], + host_bound_shape: tuple[int, ...] | None, + host_bound_device_count: int, +) -> bool: + """Returns whether one host bucket matches the expected host topology. + + Args: + host_devices: Devices believed to belong to one host. + host_bound_shape: Expected per-host chip shape, optionally with a trailing + logical-devices-per-chip dimension. + host_bound_device_count: Expected number of logical devices per host. 
+ + Returns: + True when the devices occupy exactly the expected host bounds and device + multiplicity, otherwise False. + """ + if host_bound_shape is None or host_bound_device_count <= 0: + raise ValueError( + "host_bound_shape and host_bound_device_count must be set for " + "host-group allocation." + ) + if len(host_devices) != host_bound_device_count: + return False + + chip_coords = [] + allow_2d = _pool_allows_2d_chip_coords(host_devices) + for device in host_devices: + canonical_coords = _canonical_device_chip_coords(device, allow_2d=allow_2d) + if canonical_coords is None: + return False + chip_coords.append(canonical_coords) + + if not chip_coords: + return False + + num_chip_dims = 3 + if len(host_bound_shape) not in (num_chip_dims, num_chip_dims + 1): + return False + if any(len(coords) != num_chip_dims for coords in chip_coords): + return False + + unique_chip_coords = set(chip_coords) + mins = tuple(min(coords[dim] for coords in unique_chip_coords) for dim in range(num_chip_dims)) + maxs = tuple(max(coords[dim] for coords in unique_chip_coords) for dim in range(num_chip_dims)) + chip_shape = tuple( + maxs[dim] - mins[dim] + 1 for dim in range(num_chip_dims) + ) + expected_chip_shape = host_bound_shape[:num_chip_dims] + if chip_shape != expected_chip_shape: + return False + if int(np.prod(chip_shape)) != len(unique_chip_coords): + return False + + chip_counts = collections.Counter(chip_coords) + if len(host_bound_shape) == num_chip_dims: + return len(set(chip_counts.values())) == 1 + + expected_devices_per_chip = host_bound_shape[-1] + return all(count == expected_devices_per_chip for count in chip_counts.values()) + + +def allocate_named_mesh_device_slices( + mesh_requirements: Sequence[MeshRequirement], + devices: Sequence[Any] | None = None, + *, + allocation_policy: str = _COMPACT_ALLOCATION_POLICY, +) -> dict[str, list[Any]]: + """Allocates device subsets for named meshes. 
+ + This is a convenience wrapper over `allocate_devices()` for callers that want + several named allocations from one shared device pool. + + The function builds one `DeviceAllocationState`, then calls + `allocate_devices()` once per `(mesh_name, required_devices)` pair. That + keeps the single-mesh allocation policy centralized in one public API instead + of duplicating decision logic here. + + Args: + mesh_requirements: Sequence of (mesh_name, required_devices) pairs. + Example: [("actor", 8), ("rollout", 4)]. The mesh_name is only used for + logging and as the key in the returned dictionary. + devices: Optional explicit device list. When omitted, this uses + jax.devices(). + allocation_policy: Allocation policy shared by all requested meshes in this + pass. ``COMPACT`` packs into the smallest fitting remaining region. + ``PERFORMANCE`` prefers more cubical extracted shapes. + + Returns: + A dictionary mapping each mesh name to the list of devices assigned to it. + + Raises: + ValueError: If a requested mesh cannot be assigned enough devices or if a + host-based allocation would split hosts illegally. 
+ """ + state = _create_device_allocation_state( + devices, + allocation_policy=allocation_policy, + ) + allocations = {} + + for mesh_name, required_devices in mesh_requirements: + assigned_devices, state = allocate_devices( + required_devices, + mesh_name=mesh_name, + allocation_state=state, + return_state=True, + ) + allocations[mesh_name] = assigned_devices + + unused_device_count = state.total_device_count - state.used_device_count + if unused_device_count > 0: + logging.warning( + "Mesh allocation used %d of %d devices; %d devices remain unused.", + state.used_device_count, + state.total_device_count, + unused_device_count, + ) + logging.info( + "Mesh device allocation: %s", + {mesh_name: len(assigned_devices) for mesh_name, assigned_devices in allocations.items()}, + ) + return allocations diff --git a/tunix/utils/topology.py b/tunix/utils/topology.py new file mode 100644 index 000000000..91cb463ee --- /dev/null +++ b/tunix/utils/topology.py @@ -0,0 +1,323 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Accelerator topology helpers used by Tunix mesh allocation. + +This module captures two distinct layers of TPU topology knowledge: + +1. Per-host bounds, such as `(2, 2, 1)` chips per host for multi-host fish + families. +2. Supported pod-level physical topology shapes, such as `2x2x4` or `8x16x16`. + +For fish families (`v4`, `v5p`, `tpu7x`), the supported physical pod shapes are +treated as: + +1. 
A small explicit sequence before the first full cube: `2x2x1`, `2x2x2`, + `2x2x4`, `2x4x4`. +2. Any canonical `4i x 4j x 4k` shape once the topology reaches `4x4x4`, with + `i <= j <= k` in the requested axis order. + +For `v5e` and `v6e`, supported physical shapes are canonicalized to 3D with a +trailing singleton `z`, so an edge shape like `8x16` is treated as `8x16x1`. + +Source references: + +- v4: https://docs.cloud.google.com/tpu/docs/v4 +- v5e: https://docs.cloud.google.com/tpu/docs/v5e +- v5p: https://docs.cloud.google.com/tpu/docs/v5p +- v6e: https://docs.cloud.google.com/tpu/docs/v6e +- v7x: https://docs.cloud.google.com/tpu/docs/v7x +""" + +import collections +import math +from typing import Any, Sequence + +_SINGLE_HOST_BOUNDS = (1, 1, 1) +_MULTI_HOST_BOUNDS = (2, 2, 1) +_SUPPORTED_FAMILIES = { + "v4", + "v5e", + "v5p", + "v6e", + "tpu7x", +} +_FISH_SUPPORTED_SUB_CUBE_SHAPES = ( + (2, 2, 1), + (2, 2, 2), + (2, 2, 4), + (2, 4, 4), +) +_FISH_CUBE_GRANULARITY = 4 +_EDGE_SUPPORTED_SHAPES = ( + (1, 1, 1), + (2, 2, 1), + (2, 4, 1), + (4, 4, 1), + (4, 8, 1), + (8, 8, 1), + (8, 16, 1), + (16, 16, 1), +) + + +def _topology_shape_sort_key(shape: tuple[int, ...]) -> tuple[int, tuple[int, ...], tuple[int, ...]]: + """Ranks valid shapes from more cubical to less cubical.""" + return ( + max(shape), + tuple(sorted(shape, reverse=True)), + shape, + ) + + +def _device_attr(device: Any, attr_name: str) -> Any: + """Returns a raw device attribute, calling it first when exposed lazily.""" + value = getattr(device, attr_name, None) + return value() if callable(value) else value + + +def _normalize_device_kind(device_kind: str) -> str | None: + device_kind = device_kind.lower() + if "v7" in device_kind: + return "tpu7x" + if "v6 lite" in device_kind or "v6e" in device_kind or "v6" in device_kind: + return "v6e" + if "v5 lite" in device_kind or "v5e" in device_kind: + return "v5e" + if "v5" in device_kind: + return "v5p" + if "v4" in device_kind: + return "v4" + return None 
+ + +def _resolve_family(device_kind_or_family: str) -> str | None: + """Resolves a raw device kind or normalized family key to a known family.""" + family = _normalize_device_kind(device_kind_or_family) + if family is not None: + return family + normalized = device_kind_or_family.lower() + if normalized in _SUPPORTED_FAMILIES: + return normalized + return None + + +def _device_host_key(device: Any) -> tuple[Any, ...] | None: + """Returns a stable per-host key when runtime metadata exposes one.""" + task_id = None + for attr_name in ("task_id", "process_index"): + task_id = _device_attr(device, attr_name) + if task_id is not None: + break + if task_id is None: + return None + + slice_id = _device_attr(device, "slice_index") + return (slice_id, task_id) + + +def _device_coords(device: Any) -> tuple[int, ...] | None: + coords = _device_attr(device, "coords") + if coords is None: + return None + return tuple(int(coord) for coord in coords) + + +def _canonicalize_chip_shape_to_3d(shape: Sequence[int]) -> tuple[int, int, int] | None: + """Canonicalizes a chip topology shape to `(x, y, z)` form. + + Shapes may come from edge runtimes that expose 2D chip coordinates. Those + are normalized to 3D by appending a trailing singleton `z` dimension. + """ + parsed = tuple(int(dim) for dim in shape) + if len(parsed) == 2: + return parsed + (1,) + if len(parsed) == 3: + return parsed + return None + + +def _infer_host_shape_from_runtime(devices: Sequence[Any]) -> tuple[int, ...] 
| None: + """Infers per-host chip bounds from runtime host and coord metadata.""" + host_to_coords = collections.defaultdict(list) + for device in devices: + host_key = _device_host_key(device) + coords = _device_coords(device) + if host_key is None or coords is None: + return None + host_to_coords[host_key].append(coords) + + if not host_to_coords: + return None + + host_shapes = set() + for coords_list in host_to_coords.values(): + if not coords_list: + return None + rank = len(coords_list[0]) + if any(len(coords) != rank for coords in coords_list): + return None + unique_coords = set(coords_list) + mins = tuple(min(coords[dim] for coords in unique_coords) for dim in range(rank)) + maxs = tuple(max(coords[dim] for coords in unique_coords) for dim in range(rank)) + shape = tuple(maxs[dim] - mins[dim] + 1 for dim in range(rank)) + canonical_shape = _canonicalize_chip_shape_to_3d(shape) + if canonical_shape is None: + return None + if int(math.prod(shape)) != len(unique_coords): + return None + host_shapes.add(canonical_shape) + + if len(host_shapes) != 1: + return None + return next(iter(host_shapes)) + + +def best_topology_shapes_for_chip_count( + device_kind_or_family: str, + required_chips: int, + *, + chip_rank: int = 3, + available_chip_shape: Sequence[int] | None = None, +) -> list[tuple[int, ...]]: + """Returns the best legal topology shape(s) for a requested chip count. + + Shapes are ranked from more cubical to less cubical. For fish families this + helper returns only the best-ranked 3D shape(s), which is enough for the + current allocator. For edge families, callers may request either 2D or 3D + shapes. + + Raises: + ValueError: If a fish-family request at or above `4x4x4` is not divisible + by the cube granularity volume. 
+ """ + if required_chips <= 0: + return [] + + family = _resolve_family(device_kind_or_family) + if family is None: + return [] + + parsed_available_shape = None + if available_chip_shape is not None: + parsed_available_shape = tuple(available_chip_shape) + + if family in {"v5e", "v6e"}: + canonical_edge_available_shape = None + if parsed_available_shape is not None: + canonical_edge_available_shape = _canonicalize_chip_shape_to_3d( + parsed_available_shape + ) + matching_shapes = [] + for shape in _EDGE_SUPPORTED_SHAPES: + if math.prod(shape) != required_chips: + continue + if canonical_edge_available_shape is not None: + if any( + dim > limit + for dim, limit in zip(shape, canonical_edge_available_shape) + ): + continue + matching_shapes.append(shape) + if chip_rank == 2: + return [shape[:2] for shape in matching_shapes] + if chip_rank == 3: + return matching_shapes + return [] + + if chip_rank != 3: + return [] + + if parsed_available_shape is not None and len(parsed_available_shape) != 3: + return [] + + best_shape = None + best_shape_key = None + + def consider_shape(shape: tuple[int, int, int]): + nonlocal best_shape, best_shape_key + shape_key = _topology_shape_sort_key(shape) + if best_shape_key is None or shape_key < best_shape_key: + best_shape = shape + best_shape_key = shape_key + + for shape in _FISH_SUPPORTED_SUB_CUBE_SHAPES: + if math.prod(shape) != required_chips: + continue + if parsed_available_shape is not None: + if any(dim > limit for dim, limit in zip(shape, parsed_available_shape)): + continue + best_shape = shape + best_shape_key = _topology_shape_sort_key(shape) + break + + if required_chips >= _FISH_CUBE_GRANULARITY**3: + cube_units, remainder = divmod(required_chips, _FISH_CUBE_GRANULARITY**3) + if remainder != 0: + raise ValueError( + "Fish-family topology requests at or above 4x4x4 must be divisible " + f"by {_FISH_CUBE_GRANULARITY**3} chips, got {required_chips}." 
+ ) + max_i = cube_units + max_j = cube_units + max_k = cube_units + if parsed_available_shape is not None: + max_i = min(max_i, parsed_available_shape[0] // _FISH_CUBE_GRANULARITY) + max_j = min(max_j, parsed_available_shape[1] // _FISH_CUBE_GRANULARITY) + max_k = min(max_k, parsed_available_shape[2] // _FISH_CUBE_GRANULARITY) + i = 1 + while i <= max_i and i * i <= cube_units: + j = i + while j <= max_j and i * j <= cube_units: + k, extra = divmod(cube_units, i * j) + if extra == 0 and j <= k and k <= max_k: + consider_shape( + ( + _FISH_CUBE_GRANULARITY * i, + _FISH_CUBE_GRANULARITY * j, + _FISH_CUBE_GRANULARITY * k, + ) + ) + j += 1 + i += 1 + + if best_shape is None: + return [] + return [best_shape] + + +def infer_chips_per_host_bounds( + devices: Sequence[Any], +) -> tuple[int, ...] | None: + if not devices: + return None + + device_kind = _device_attr(devices[0], "device_kind") + if not isinstance(device_kind, str): + return None + + family = _normalize_device_kind(device_kind) + if family is None: + return None + + runtime_host_shape = _infer_host_shape_from_runtime(devices) + if family in {"v5e", "v6e"} and runtime_host_shape is not None: + return runtime_host_shape + + device_count = len(devices) + if family in {"v5e", "v6e"} and device_count == 1: + return _SINGLE_HOST_BOUNDS + if family == "tpu7x" and device_count == 2: + return _SINGLE_HOST_BOUNDS + return _MULTI_HOST_BOUNDS