Skip to content

Commit cc5036c

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Raise a better error message if anything other than a sequence of ints is passed to make_mesh or create_device_mesh
PiperOrigin-RevId: 700779838
1 parent 6e72592 commit cc5036c

File tree

4 files changed

+21
-1
lines changed

4 files changed

+21
-1
lines changed

jax/_src/mesh_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,12 @@ def _transpose_trick(
705705
*_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]
706706
)
707707

708+
def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str,
709+
fun_name: str):
710+
if not all(isinstance(s, int) for s in axis_shapes):
711+
raise ValueError(
712+
f'{arg_name} passed to {fun_name} should be a sequence of ints. Got'
713+
f' {axis_shapes}')
708714

709715
def create_device_mesh(
710716
mesh_shape: Sequence[int],
@@ -740,7 +746,8 @@ def create_device_mesh(
740746
"""
741747
if devices is None:
742748
devices = xb.devices()
743-
if np.prod(mesh_shape) != len(devices):
749+
_validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh')
750+
if math.prod(mesh_shape) != len(devices):
744751
raise ValueError(
745752
f'Number of devices {len(devices)} must equal the product '
746753
f'of mesh_shape {mesh_shape}'

jax/_src/sharding_impls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,7 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
17141714
"""
17151715
if devices is None:
17161716
devices = xla_bridge.devices()
1717+
mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh')
17171718
axis_size = math.prod(axis_shapes)
17181719
if axis_size > len(devices):
17191720
raise ValueError(

tests/mesh_utils_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,12 @@ def test_create_device_mesh_for_nd_torus(
353353
)
354354
self.assertArraysEqual(assignment, expected_assignment_matrix)
355355

356+
def test_create_device_mesh_non_int_error(self):
357+
with self.assertRaisesRegex(
358+
ValueError,
359+
"mesh_shape passed to create_device_mesh should be a sequence of ints"):
360+
mesh_utils.create_device_mesh(((4,), 4))
361+
356362
@parameterized.named_parameters(
357363
('2x2x1', mock_2x2x1_devices,),
358364
('2x2x4', mock_2x2x4_devices, ),

tests/pjit_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4458,6 +4458,12 @@ def g(x):
44584458
self.assertEqual(out2.sharding, s)
44594459
self.assertEqual(out2.dtype, np.float32)
44604460

4461+
def test_make_mesh_non_int_error(self):
4462+
with self.assertRaisesRegex(
4463+
ValueError,
4464+
"axis_shapes passed to make_mesh should be a sequence of ints"):
4465+
jax.make_mesh(((4,), 4), ('x', 'y'))
4466+
44614467
def test_jnp_array_reshard_error(self):
44624468
if jax.device_count() < 2:
44634469
self.skipTest('Requires >=2 devices')

0 commit comments

Comments
 (0)