|
27 | 27 | from tunix.sft import peft_trainer |
28 | 28 | from tunix.tests import test_common as tc |
29 | 29 | from tunix.utils import env_utils |
| 30 | +from tunix.utils import mesh as mesh_lib |
| 31 | + |
| 32 | +os.environ.setdefault("HF_TOKEN", "TestToken") |
30 | 33 |
|
31 | 34 |
|
32 | 35 | class ConfigTest(parameterized.TestCase): |
@@ -262,7 +265,7 @@ def test_learning_rate_schedule_valid(self, overrides): |
262 | 265 | self.assertIsNotNone(lr_schedule) |
263 | 266 | self.assertTrue(callable(lr_schedule), "lr_schedule should be callable") |
264 | 267 |
|
265 | | - # --- Tests for create_mesh --- |
| 268 | + # --- Tests for mesh config parsing and mesh creation --- |
266 | 269 | @parameterized.named_parameters( |
267 | 270 | dict( |
268 | 271 | testcase_name="valid_1d", |
@@ -311,40 +314,93 @@ def test_create_mesh_valid( |
311 | 314 | ): |
312 | 315 | mock_device_count_fn.return_value = mock_num_devices |
313 | 316 | hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
314 | | - mesh = hp.create_mesh("model_config") |
315 | | - self.assertEqual( |
316 | | - mesh, |
317 | | - jax.make_mesh( |
318 | | - expected[0], |
319 | | - expected[1], |
320 | | - axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), |
321 | | - ), |
322 | | - ) |
| 317 | + axis_shapes, axis_names = hp._parse_mesh_config("model_config") |
| 318 | + expected_mesh = object() |
323 | 319 |
|
324 | | - def test_create_mesh_with_assigned_devices(self): |
325 | | - raw_keys = { |
326 | | - "model_config": { |
327 | | - "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} |
328 | | - } |
329 | | - } |
330 | | - hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
331 | | - assigned_devices = ["d0", "d1", "d2", "d3"] |
| 320 | + with mock.patch.object(jax, "make_mesh", return_value=expected_mesh) as make_mesh_mock: |
| 321 | + mesh = mesh_lib.create_mesh(axis_shapes, axis_names) |
332 | 322 |
|
333 | | - class FakeMesh: |
| 323 | + make_mesh_mock.assert_called_once_with( |
| 324 | + expected[0], |
| 325 | + expected[1], |
| 326 | + axis_types=(jax.sharding.AxisType.Auto,) * len(expected[1]), |
| 327 | + ) |
| 328 | + self.assertIs(mesh, expected_mesh) |
| 329 | + |
| 330 | + def test_create_mesh_with_assigned_devices(self): |
| 331 | + raw_keys = { |
| 332 | + "model_config": { |
| 333 | + "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} |
| 334 | + } |
| 335 | + } |
| 336 | + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
| 337 | + axis_shapes, axis_names = hp._parse_mesh_config("model_config") |
| 338 | + assigned_devices = ["d0", "d1", "d2", "d3"] |
334 | 339 |
|
335 | | - def __init__(self, devices, axis_names, axis_types=None): |
336 | | - self.devices = devices |
337 | | - self.axis_names = axis_names |
338 | | - self.axis_types = axis_types |
| 340 | + class FakeMesh: |
339 | 341 |
|
340 | | - with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): |
341 | | - mesh = hp.create_mesh("model_config", devices=assigned_devices) |
| 342 | + def __init__(self, devices, axis_names, axis_types=None): |
| 343 | + self.devices = devices |
| 344 | + self.axis_names = axis_names |
| 345 | + self.axis_types = axis_types |
342 | 346 |
|
343 | | - self.assertEqual(mesh.devices.shape, (2, 2)) |
344 | | - self.assertSequenceEqual( |
345 | | - mesh.devices.flatten().tolist(), assigned_devices |
| 347 | + with mock.patch.object(jax.sharding, "Mesh", side_effect=FakeMesh): |
| 348 | + mesh = mesh_lib.create_mesh( |
| 349 | + axis_shapes, |
| 350 | + axis_names, |
| 351 | + devices=assigned_devices, |
346 | 352 | ) |
347 | | - self.assertEqual(mesh.axis_names, ("x", "y")) |
| 353 | + |
| 354 | + self.assertEqual(mesh.devices.shape, (2, 2)) |
| 355 | + self.assertSequenceEqual( |
| 356 | + mesh.devices.flatten().tolist(), assigned_devices |
| 357 | + ) |
| 358 | + self.assertEqual(mesh.axis_names, ("x", "y")) |
| 359 | + |
| 360 | + def test_parse_mesh_allocation_policy_defaults_to_compact(self): |
| 361 | + raw_keys = { |
| 362 | + "model_config": { |
| 363 | + "mesh": {"shape": "(2, 2)", "axis_names": "('x', 'y')"} |
| 364 | + } |
| 365 | + } |
| 366 | + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
| 367 | + |
| 368 | + self.assertEqual( |
| 369 | + hp._parse_mesh_allocation_policy("model_config"), |
| 370 | + mesh_lib.normalize_allocation_policy(None), |
| 371 | + ) |
| 372 | + |
| 373 | + def test_parse_mesh_allocation_policy_validates_explicit_value(self): |
| 374 | + raw_keys = { |
| 375 | + "model_config": { |
| 376 | + "mesh": { |
| 377 | + "shape": "(2, 2)", |
| 378 | + "axis_names": "('x', 'y')", |
| 379 | + "allocation_policy": "performance", |
| 380 | + } |
| 381 | + } |
| 382 | + } |
| 383 | + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
| 384 | + |
| 385 | + self.assertEqual( |
| 386 | + hp._parse_mesh_allocation_policy("model_config"), |
| 387 | + "PERFORMANCE", |
| 388 | + ) |
| 389 | + |
| 390 | + def test_parse_mesh_allocation_policy_rejects_invalid_value(self): |
| 391 | + raw_keys = { |
| 392 | + "model_config": { |
| 393 | + "mesh": { |
| 394 | + "shape": "(2, 2)", |
| 395 | + "axis_names": "('x', 'y')", |
| 396 | + "allocation_policy": "fastest", |
| 397 | + } |
| 398 | + } |
| 399 | + } |
| 400 | + hp = self.initialize_config(self.convert_nested_dict_to_list(raw_keys)) |
| 401 | + |
| 402 | + with self.assertRaisesRegex(ValueError, "allocation_policy must be one of"): |
| 403 | + hp._parse_mesh_allocation_policy("model_config") |
348 | 404 |
|
349 | 405 | @parameterized.named_parameters( |
350 | 406 | dict( |
@@ -424,11 +480,12 @@ def test_create_mesh_invalid( |
424 | 480 | mock_num_devices, |
425 | 481 | error_regex, |
426 | 482 | ): |
427 | | - mock_device_count_fn.return_value = mock_num_devices |
428 | | - with self.assertRaisesRegex(ValueError, error_regex): |
429 | | - nested_dict = self.convert_nested_dict_to_list(raw_keys) |
430 | | - hp = self.initialize_config(nested_dict) |
431 | | - hp.create_mesh("model_config") |
| 483 | + mock_device_count_fn.return_value = mock_num_devices |
| 484 | + with self.assertRaisesRegex(ValueError, error_regex): |
| 485 | + nested_dict = self.convert_nested_dict_to_list(raw_keys) |
| 486 | + hp = self.initialize_config(nested_dict) |
| 487 | + axis_shapes, axis_names = hp._parse_mesh_config("model_config") |
| 488 | + mesh_lib.create_mesh(axis_shapes, axis_names) |
432 | 489 |
|
433 | 490 | @parameterized.named_parameters( |
434 | 491 | dict( |
|
0 commit comments