Skip to content

Commit 6e84e21

Browse files
BrianWiederOrbax Authors
authored andcommitted
Raise a more useful error when doing a dtensor export with pspecs but not all params are jax arrays.
PiperOrigin-RevId: 875283768
1 parent d83491f commit 6e84e21

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

export/orbax/export/modules/tensorflow_module.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,22 @@ def jax_params_to_tf_variables(
276276
if not all(
277277
isinstance(x, jax.Array) for x in jax.tree_util.tree_leaves(params)
278278
):
279-
logging.warning(
280-
'Some params are not jax.Array, DTensor export will not take'
281-
' effect.Falling back to traditional TF export.'
279+
flattened_params = jax.tree_util.tree_leaves_with_path(params)
280+
non_jax_array_info = [
281+
f'{jax.tree_util.keystr(path)} (type: {type(x).__name__})'
282+
for path, x in flattened_params
283+
if not isinstance(x, jax.Array)
284+
]
285+
if len(non_jax_array_info) > 10:
286+
omitted = len(non_jax_array_info) - 10
287+
non_jax_array_info = non_jax_array_info[:10] + [
288+
f'...and {omitted} more'
289+
]
290+
raise ValueError(
291+
'All parameters must be JAX arrays when DTensor export is enabled.'
292+
' Found non-JAX array parameters at:'
293+
f' {", ".join(non_jax_array_info)}'
282294
)
283-
mesh = None
284295

285296
if mesh is None and pspecs is not None:
286297
raise ValueError(

export/orbax/export/modules/tensorflow_module_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import collections
1616
import logging
1717
import os
18+
from unittest import mock
1819

1920
from absl.testing import parameterized
2021
import chex
@@ -180,6 +181,30 @@ def test_jax_array(self):
180181
self.assertEqual(variables[0].name, 'arr:0')
181182
self.assertAllEqual(variables[0], global_input_data)
182183

184+
def test_dtensor_export_non_jax_array_error(self):
185+
global_mesh = jax.sharding.Mesh(
186+
np.array(jax.local_devices(backend='cpu')), 'x'
187+
)
188+
mesh_axes = jax.sharding.PartitionSpec('x')
189+
190+
arr = jnp.array([1, 2, 3])
191+
with mock.patch.object(
192+
tensorflow_module.dtensor_utils,
193+
'get_current_mesh',
194+
return_value=global_mesh,
195+
):
196+
with self.assertRaisesRegex(
197+
ValueError,
198+
r'All parameters must be JAX arrays when DTensor export is enabled.*'
199+
r"Found non-JAX array parameters at: \['some_non_array'\] \(type:"
200+
r' ellipsis\)',
201+
):
202+
TensorFlowModule(
203+
params={'arr': arr, 'some_non_array': ...},
204+
apply_fn=DEFAULT_APPLY_FN,
205+
pspecs={'arr': mesh_axes, 'some_non_array': None},
206+
)
207+
183208
@parameterized.parameters(True, False)
184209
def test_call_tf_module_methods(self, jit_compile):
185210

0 commit comments

Comments
 (0)