@@ -240,7 +240,7 @@ present on the exporting machine:
240240
241241```
242242
243- There is a safety check that will be raise an error when trying to compile
243+ There is a safety check that will raise an error when trying to compile
244244an `Exported` object on a machine that does not have the accelerator
245245for which the code was exported.
246246
@@ -326,7 +326,7 @@ combinations of input shapes.
326326
327327See the {ref}`shape_poly` documentation.
328328
329- ## Device polymorphic export
329+ ## Device-polymorphic export
330330
331331An exported artifact may contain sharding annotations for inputs,
332332outputs and for some intermediates, but these annotations do not refer
@@ -335,20 +335,28 @@ Instead, the sharding annotations refer to logical devices. This
335335means that you can compile and run the exported artifacts on
336336physical devices different from those that were used for exporting.
337337
338+ The cleanest way to achieve a device-polymorphic export is to
339+ use shardings constructed with a `jax.sharding.AbstractMesh`,
340+ which contains only the mesh shape and axis names. However,
341+ you can achieve the same results if you use shardings
342+ constructed for a mesh with concrete devices, since the actual
343+ devices in the mesh are ignored for tracing and lowering:
344+
338345```python
339346>>> import jax
340347>>> from jax import export
341- >>> from jax.sharding import Mesh, NamedSharding
348+ >>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
342349>>> from jax.sharding import PartitionSpec as P
350+ >>>
351+ >>> # Use an AbstractMesh for exporting
352+ >>> export_mesh = AbstractMesh((("a", 4),))
343353
344- >>> # Use the first 4 devices for exporting.
345- >>> export_devices = jax.local_devices()[:4]
346- >>> export_mesh = Mesh(export_devices, ("a",))
347354>>> def f(x):
348355...   return x.T
349356
350- >>> arg = jnp.arange(8 * len(export_devices))
351- >>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
357+ >>> exp = export.export(jax.jit(f))(
358+ ...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
359+ ...                          sharding=NamedSharding(export_mesh, P("a"))))
352360
353361>>> # `exp` knows for how many devices it was exported.
354362>>> exp.nr_devices
@@ -359,8 +367,20 @@ physical devices that were used for exporting.
359367>>> exp.in_shardings_hlo
360368({devices=[4]<=[4]},)
361369
370+ >>> # You can also use a concrete set of devices for exporting
371+ >>> concrete_devices = jax.local_devices()[:4]
372+ >>> concrete_mesh = Mesh(concrete_devices, ("a",))
373+ >>> exp2 = export.export(jax.jit(f))(
374+ ...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
375+ ...                          sharding=NamedSharding(concrete_mesh, P("a"))))
376+
377+ >>> # You can expect the same results
378+ >>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo
379+
380+ >>> # When you call an Exported, you must use a concrete set of devices
381+ >>> arg = jnp.arange(8 * 4)
362382>>> res1 = exp.call(jax.device_put(arg,
363- ...                                NamedSharding(export_mesh, P("a"))))
383+ ...                                NamedSharding(concrete_mesh, P("a"))))
364384
365385>>> # Check out the first 2 shards of the result
366386>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
@@ -397,9 +417,11 @@ of devices than it was exported for:
397417>>> def f(x):
398418...   return x.T
399419
400- >>> arg = jnp.arange(4 * len(export_devices))
401- >>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
420+ >>> exp = export.export(jax.jit(f))(
421+ ...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
422+ ...                          sharding=NamedSharding(export_mesh, P("a"))))
402423
424+ >>> arg = jnp.arange(4 * len(export_devices))
403425>>> exp.call(arg)  # doctest: +IGNORE_EXCEPTION_DETAIL
404426Traceback (most recent call last):
405427ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
@@ -420,13 +442,16 @@ artifacts using a new mesh constructed at the call site:
420442>>> def f(x):
421443...   return x.T
422444
423- >>> arg = jnp.arange(4 * len(export_devices))
424- >>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
445+
446+ >>> exp = export.export(jax.jit(f))(
447+ ...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
448+ ...                          sharding=NamedSharding(export_mesh, P("a"))))
425449
426450>>> # Prepare the mesh for calling `exp`.
427451>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))
428452
429453>>> # Shard the arg according to what `exp` expects.
454+ >>> arg = jnp.arange(4 * len(export_devices))
430455>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
431456>>> res = exp.call(sharded_arg)
432457
0 commit comments