|
27 | 27 | from tunix.sft import peft_trainer |
28 | 28 | from tunix.tests import test_common as tc |
29 | 29 | from tunix.utils import env_utils |
| 30 | +from tunix.utils import mesh as mesh_lib |
30 | 31 |
|
31 | 32 |
|
32 | 33 | class ConfigTest(parameterized.TestCase): |
@@ -262,7 +263,7 @@ def test_learning_rate_schedule_valid(self, overrides): |
262 | 263 | self.assertIsNotNone(lr_schedule) |
263 | 264 | self.assertTrue(callable(lr_schedule), "lr_schedule should be callable") |
264 | 265 |
|
265 | | - # --- Tests for create_mesh --- |
| 266 | + # --- Tests for mesh config parsing and mesh creation --- |
266 | 267 | @parameterized.named_parameters( |
267 | 268 | dict( |
268 | 269 | testcase_name="valid_1d", |
@@ -311,40 +312,93 @@ def test_create_mesh_valid( |
311 | 312 | ): |
312 | 313 | mock_device_count_fn.return_value = mock_num_devices |
313 | 314 | hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
314 | | - mesh = hp.create_mesh("model_config") |
315 | | - self.assertEqual( |
316 | | - mesh, |
317 | | - jax.make_mesh( |
318 | | - expected[0], |
319 | | - expected[1], |
320 | | - axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), |
321 | | - ), |
322 | | - ) |
| 315 | + axis_shapes, axis_names = hp._parse_mesh_config("model_config") |
| 316 | + expected_mesh = object() |
323 | 317 |
|
324 | | - def test_create_mesh_with_assigned_devices(self): |
325 | | - raw_keys = { |
326 | | - "model_config": { |
327 | | - "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} |
328 | | - } |
329 | | - } |
330 | | - hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
331 | | - assigned_devices = ["d0", "d1", "d2", "d3"] |
| 318 | + with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock: |
| 319 | + mesh = mesh_lib.create_mesh(axis_shapes, axis_names) |
332 | 320 |
|
333 | | - class FakeMesh: |
| 321 | + make_mesh_mock.assert_called_once_with( |
| 322 | + expected[0], |
| 323 | + expected[1], |
| 324 | + axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), |
| 325 | + ) |
| 326 | + self.assertIs(mesh, expected_mesh) |
| 327 | + |
| 328 | + def test_create_mesh_with_assigned_devices(self): |
| 329 | + raw_keys = { |
| 330 | + "model_config": { |
| 331 | + "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} |
| 332 | + } |
| 333 | + } |
| 334 | + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
| 335 | + axis_shapes, axis_names = hp._parse_mesh_config("model_config") |
| 336 | + assigned_devices = ["d0", "d1", "d2", "d3"] |
334 | 337 |
|
335 | | - def __init__(self, devices, axis_names, axis_types=None): |
336 | | - self.devices = devices |
337 | | - self.axis_names = axis_names |
338 | | - self.axis_types = axis_types |
| 338 | + class FakeMesh: |
339 | 339 |
|
340 | | - with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): |
341 | | - mesh = hp.create_mesh("model_config", devices=assigned_devices) |
| 340 | + def __init__(self, devices, axis_names, axis_types=None): |
| 341 | + self.devices = devices |
| 342 | + self.axis_names = axis_names |
| 343 | + self.axis_types = axis_types |
342 | 344 |
|
343 | | - self.assertEqual(mesh.devices.shape, (2, 2)) |
344 | | - self.assertSequenceEqual( |
345 | | - mesh.devices.flatten().tolist(), assigned_devices |
| 345 | + with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): |
| 346 | + mesh = mesh_lib.create_mesh( |
| 347 | + axis_shapes, |
| 348 | + axis_names, |
| 349 | + devices=assigned_devices, |
346 | 350 | ) |
347 | | - self.assertEqual(mesh.axis_names, ("x", "y")) |
| 351 | + |
| 352 | + self.assertEqual(mesh.devices.shape, (2, 2)) |
| 353 | + self.assertSequenceEqual( |
| 354 | + mesh.devices.flatten().tolist(), assigned_devices |
| 355 | + ) |
| 356 | + self.assertEqual(mesh.axis_names, ("x", "y")) |
| 357 | + |
def test_parse_mesh_allocation_policy_defaults_to_compact(self):
  """Checks the parsed policy when the mesh config omits `allocation_policy`.

  The mesh section below carries only `shape` and `axis_names`. The parsed
  policy must equal whatever `mesh_lib.normalize_allocation_policy(None)`
  returns, so the test does not hard-code the default's spelling.
  """
  mesh_section = {"shape": "(2, 2)", "axis_names": "('x', 'y')"}
  config_tree = {"model_config": {"mesh": mesh_section}}
  config = self.initialize_config(
      self.convert_nested_dict_to_list(config_tree)
  )

  parsed_policy = config._parse_mesh_allocation_policy("model_config")
  default_policy = mesh_lib.normalize_allocation_policy(None)
  self.assertEqual(parsed_policy, default_policy)
| 370 | + |
def test_parse_mesh_allocation_policy_validates_explicit_value(self):
  """Checks that an explicit lowercase policy parses to its upper-case form.

  The config supplies `allocation_policy` as "performance"; the parser is
  expected to hand back "PERFORMANCE".
  """
  mesh_section = {
      "shape": "(2, 2)",
      "axis_names": "('x', 'y')",
      "allocation_policy": "performance",
  }
  config_tree = {"model_config": {"mesh": mesh_section}}
  config = self.initialize_config(
      self.convert_nested_dict_to_list(config_tree)
  )

  parsed_policy = config._parse_mesh_allocation_policy("model_config")
  self.assertEqual(parsed_policy, "PERFORMANCE")
| 387 | + |
def test_parse_mesh_allocation_policy_rejects_invalid_value(self):
  """Checks that an unrecognised policy string raises `ValueError`.

  The config supplies `allocation_policy` as "fastest"; parsing must fail
  with a message matching "allocation_policy must be one of".
  """
  mesh_section = {
      "shape": "(2, 2)",
      "axis_names": "('x', 'y')",
      "allocation_policy": "fastest",
  }
  config_tree = {"model_config": {"mesh": mesh_section}}
  # The config is built outside the assertion so that only the parse call
  # itself is expected to raise.
  config = self.initialize_config(
      self.convert_nested_dict_to_list(config_tree)
  )

  self.assertRaisesRegex(
      ValueError,
      "allocation_policy must be one of",
      config._parse_mesh_allocation_policy,
      "model_config",
  )
348 | 402 |
|
349 | 403 | @parameterized.named_parameters( |
350 | 404 | dict( |
@@ -424,11 +478,12 @@ def test_create_mesh_invalid( |
424 | 478 | mock_num_devices, |
425 | 479 | error_regex, |
426 | 480 | ): |
427 | | - mock_device_count_fn.return_value = mock_num_devices |
428 | | - with self.assertRaisesRegex(ValueError, error_regex): |
429 | | - nested_dict = self.convert_nested_dict_to_list(raw_keys) |
430 | | - hp = self.initialize_config(nested_dict) |
431 | | - hp.create_mesh("model_config") |
| 481 | + mock_device_count_fn.return_value = mock_num_devices |
| 482 | + with self.assertRaisesRegex(ValueError, error_regex): |
| 483 | + nested_dict = self.convert_nested_dict_to_list(raw_keys) |
| 484 | + hp = self.initialize_config(nested_dict) |
| 485 | + axis_shapes, axis_names = hp._parse_mesh_config("model_config") |
| 486 | + mesh_lib.create_mesh(axis_shapes, axis_names) |
432 | 487 |
|
433 | 488 | @parameterized.named_parameters( |
434 | 489 | dict( |
|
0 commit comments