
Commit 398e170

hawkinsp authored and Flax Authors committed
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:

* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding

PiperOrigin-RevId: 506995236
1 parent 06529c9 commit 398e170
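
For orientation, here is a minimal sketch of the post-migration spellings; it is a rough illustration rather than code from this commit, and the 2x4 device layout and axis names are assumptions borrowed from the guides below:

```python
# Sketch only: the jax.sharding names this commit migrates the docs to.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes 8 visible devices arranged as a 2x4 mesh, as in the guides below.
devices = np.asarray(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, ('x', 'y'))        # previously jax.experimental.maps.Mesh / jax.interpreters.pxla.Mesh
spec = PartitionSpec('x', 'y')          # previously jax.experimental.pjit.PartitionSpec and friends
sharding = NamedSharding(mesh, spec)    # previously jax.experimental.maps.NamedSharding
```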

4 files changed, +14 −12 lines changed


docs/guides/flax_on_pjit.ipynb

+2 −2
@@ -286,7 +286,7 @@
 "source": [
 "## Specify sharding (includes initialization and `TrainState` creation)\n",
 "\n",
-"Next, generate the [`jax.experimental.pjit.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?#more-information-on-partitionspec) that `pjit` should receive as annotations of _input_ and _output_ data. `PartitionSpec` is a tuple of 2 axes (in a 2x4 mesh). To learn more, refer to [JAX-101: Introduction to `pjit`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).\n",
+"Next, generate the [`jax.sharding.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?#more-information-on-partitionspec) that `pjit` should receive as annotations of _input_ and _output_ data. `PartitionSpec` is a tuple of 2 axes (in a 2x4 mesh). To learn more, refer to [JAX-101: Introduction to `pjit`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).\n",
 "\n",
 "### Specify the input\n",
 "\n",
@@ -416,7 +416,7 @@
 "\n",
 "Now you can apply JAX [`pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit) to your `init_fn` in a similar fashion as [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) but with two extra arguments: `in_axis_resources` and `out_axis_resources`.\n",
 "\n",
-"You need to add a `with mesh:` context when running a `pjit`ted function, so that it can refer to `mesh` (an instance of `jax.experimental.maps.Mesh`) to allocate data on devices correctly."
+"You need to add a `with mesh:` context when running a `pjit`ted function, so that it can refer to `mesh` (an instance of `jax.sharding.Mesh`) to allocate data on devices correctly."
 ]
 },
 {
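
For reference, a rough sketch of the pattern the changed cells describe, written with the new `jax.sharding` names: `pjit` wraps an init function with `in_axis_resources`/`out_axis_resources` and runs under a `with mesh:` context. The axis names, shapes, and the trivial `init_fn` body here are placeholders, not code from the guide.

```python
# Hedged sketch of the pjit pattern described in the changed cells above.
import jax
import numpy as np
from jax.experimental import pjit
from jax.sharding import Mesh, PartitionSpec

devices = np.asarray(jax.devices()).reshape(2, 4)   # assumes a 2x4 device mesh
mesh = Mesh(devices, ('x', 'y'))                    # axis names are illustrative

x_spec = PartitionSpec('x', None)       # placeholder annotation for the input
out_spec = PartitionSpec(None, 'y')     # placeholder annotation for the output

def init_fn(x):
    # Stand-in for the parameter/TrainState initialization in the guide.
    return x * 2.0

pjit_init_fn = pjit.pjit(init_fn,
                         in_axis_resources=x_spec,
                         out_axis_resources=out_spec)

with mesh:  # mesh is a jax.sharding.Mesh, so the context manager works as before
    out = pjit_init_fn(np.zeros((8, 4), dtype=np.float32))
```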

docs/guides/flax_on_pjit.md

+2 −2
@@ -207,7 +207,7 @@ class MLP(nn.Module):
 
 ## Specify sharding (includes initialization and `TrainState` creation)
 
-Next, generate the [`jax.experimental.pjit.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?#more-information-on-partitionspec) that `pjit` should receive as annotations of _input_ and _output_ data. `PartitionSpec` is a tuple of 2 axes (in a 2x4 mesh). To learn more, refer to [JAX-101: Introduction to `pjit`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).
+Next, generate the [`jax.sharding.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?#more-information-on-partitionspec) that `pjit` should receive as annotations of _input_ and _output_ data. `PartitionSpec` is a tuple of 2 axes (in a 2x4 mesh). To learn more, refer to [JAX-101: Introduction to `pjit`](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html).
 
 ### Specify the input
 
@@ -275,7 +275,7 @@ state_spec
 
 Now you can apply JAX [`pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit) to your `init_fn` in a similar fashion as [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) but with two extra arguments: `in_axis_resources` and `out_axis_resources`.
 
-You need to add a `with mesh:` context when running a `pjit`ted function, so that it can refer to `mesh` (an instance of `jax.experimental.maps.Mesh`) to allocate data on devices correctly.
+You need to add a `with mesh:` context when running a `pjit`ted function, so that it can refer to `mesh` (an instance of `jax.sharding.Mesh`) to allocate data on devices correctly.
 
 ```{code-cell} ipython3
 :id: a298c5d03c0d

docs/guides/use_checkpointing.ipynb

+5 −4
@@ -518,7 +518,8 @@
 "outputs": [],
 "source": [
 "# Multi-host related imports.\n",
-"from jax.experimental import maps, PartitionSpec, pjit"
+"from jax.sharding import PartitionSpec\n",
+"from jax.experimental import pjit"
 ]
 },
 {
@@ -531,14 +532,14 @@
 "# Create a multi-process array.\n",
 "mesh_shape = (4, 2)\n",
 "devices = np.asarray(jax.devices()).reshape(*mesh_shape)\n",
-"mesh = maps.Mesh(devices, ('x', 'y'))\n",
+"mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n",
 "\n",
 "f = pjit.pjit(\n",
 " lambda x: x,\n",
 " in_axis_resources=None,\n",
 " out_axis_resources=PartitionSpec('x', 'y'))\n",
 "\n",
-"with maps.Mesh(mesh.devices, mesh.axis_names):\n",
+"with jax.sharding.Mesh(mesh.devices, mesh.axis_names):\n",
 " mp_array = f(np.arange(8 * 2).reshape(8, 2))\n",
 "\n",
 "# Make it a pytree as usual.\n",
@@ -619,7 +620,7 @@
 }
 ],
 "source": [
-"with maps.Mesh(mesh.devices, mesh.axis_names):\n",
+"with jax.sharding.Mesh(mesh.devices, mesh.axis_names):\n",
 " mp_smaller_array = f(np.zeros(8).reshape(4, 2))\n",
 "\n",
 "mp_target = {'model': mp_smaller_array}\n",

docs/guides/use_checkpointing.md

+5 −4
@@ -255,21 +255,22 @@ Unfortunately, Python Jupyter notebooks are single-host only and cannot activate
 
 ```python
 # Multi-host related imports.
-from jax.experimental import maps, PartitionSpec, pjit
+from jax.sharding import PartitionSpec
+from jax.experimental import pjit
 ```
 
 ```python
 # Create a multi-process array.
 mesh_shape = (4, 2)
 devices = np.asarray(jax.devices()).reshape(*mesh_shape)
-mesh = maps.Mesh(devices, ('x', 'y'))
+mesh = jax.sharding.Mesh(devices, ('x', 'y'))
 
 f = pjit.pjit(
  lambda x: x,
  in_axis_resources=None,
  out_axis_resources=PartitionSpec('x', 'y'))
 
-with maps.Mesh(mesh.devices, mesh.axis_names):
+with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
  mp_array = f(np.arange(8 * 2).reshape(8, 2))
 
 # Make it a pytree as usual.
@@ -297,7 +298,7 @@ checkpoints.save_checkpoint_multiprocess(ckpt_dir,
 Note that, when using [`flax.training.checkpoints.restore_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.restore_checkpoint), you need to pass a `target` with valid multi-process arrays at the correct structural location. Flax only uses the `target` arrays' meshes and mesh axes to restore the checkpoint. This means that the multi-process array in the `target` arg doesn't have to be as large as your checkpoint's size (the shape of the multi-process array doesn't need to have the same shape as the actual array in your checkpoint).
 
 ```python
-with maps.Mesh(mesh.devices, mesh.axis_names):
+with jax.sharding.Mesh(mesh.devices, mesh.axis_names):
  mp_smaller_array = f(np.zeros(8).reshape(4, 2))
 
 mp_target = {'model': mp_smaller_array}
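
And a hedged sketch of the restore step the note in that hunk describes, with the target pytree carrying the mesh and axes Flax should use; the `ckpt_dir` value is a placeholder:

```python
# Hedged sketch: restore into the multi-process target built above.
from flax.training import checkpoints

restored = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/mp_ckpt',  # placeholder; point this at your real checkpoint directory
    target=mp_target)
# Flax uses mp_target's mesh and mesh axes to lay out the restored arrays,
# even though mp_smaller_array's shape differs from the checkpointed array.
```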

0 commit comments