
Commit d3f3672

Cristian Garcia (Flax Authors) authored and committed
Rename sharding_names to out_sharding in NNX Variable metadata
This CL renames the `sharding_names` attribute to `out_sharding` for better consistency with the sharding API. The new name more clearly indicates the purpose of this metadata field.

## Changes

- Bump Flax version to 0.12.4
- Core changes in variablelib.py:
  - Add `sharding_names` to `out_sharding` metadata remapping for backward compatibility
  - Add deprecated `sharding_names` property that returns `out_sharding` with a warning
- Update nnx/spmd.py, core/spmd.py, core/meta.py, linen/spmd.py to use `out_sharding`
- Update all NNX tests to use the new attribute name
- Update qwix flax_util.py to check for `out_sharding` first, with fallback to `sharding_names`
- Update maxtext initializers.py to check for `out_sharding` first
- Update documentation and examples to use `out_sharding`

## Backward Compatibility

Existing code using `sharding_names` will continue to work via:

- Metadata remapping during Variable creation
- Deprecated `Variable.sharding_names` property

PiperOrigin-RevId: 859745972
1 parent 4089582 commit d3f3672
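The backward-compatibility shim described in the commit message can be summarized with a minimal sketch (the class body below is illustrative, not the actual `variablelib.py` implementation):

```python
import warnings

class Variable:
  """Illustrative stand-in for nnx.Variable's metadata handling."""

  def __init__(self, value, **metadata):
    # Remap the legacy key so existing call sites keep working.
    if 'sharding_names' in metadata:
      metadata['out_sharding'] = metadata.pop('sharding_names')
    self.value = value
    self.out_sharding = metadata.get('out_sharding')

  @property
  def sharding_names(self):
    # Deprecated alias: warn, then defer to the new attribute.
    warnings.warn(
        '`sharding_names` is deprecated; use `out_sharding` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    return self.out_sharding
```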

18 files changed: +90 -80 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ To cite this repository:
   author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
   title = {{F}lax: A neural network library and ecosystem for {JAX}},
   url = {http://github.com/google/flax},
-  version = {0.12.3},
+  version = {0.12.4},
   year = {2024},
 }
 ```

docs_nnx/flip/4844-var-eager-sharding.md

Lines changed: 2 additions & 2 deletions
@@ -56,12 +56,12 @@ with jax.set_mesh(mesh):
   ...
 ```
 
-For JAX explicit mode, remove the `sharding_names=` annotation on the `nnx.Variable`.
+For JAX explicit mode, remove the `out_sharding=` annotation on the `nnx.Variable`.
 
 
 # Implementation
 [implementation]: #implementation
 
-When an `nnx.Variable` is created, check for the metadata `sharding_names`, and if present, check if under a valid global mesh context of was supplied with a valid mesh. If no, throw error; if yes, call `jax.lax.with_sharding_constraint` to apply sharding constraint on the value.
+When an `nnx.Variable` is created, check for the metadata `out_sharding`, and if present, check whether it is under a valid global mesh context or was supplied with a valid mesh. If not, throw an error; if so, call `jax.lax.with_sharding_constraint` to apply a sharding constraint on the value.
 
 Note that this only works in auto sharding mode. Users should use JAX-level APIs to annotate shardings for explicit mode.
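As a concrete illustration of the implementation paragraph above, a minimal sketch of the eager-sharding check, assuming auto sharding mode (`maybe_eager_shard` is an illustrative name; the real logic lives in `flax/core/spmd.py`):

```python
import jax
from jax.sharding import NamedSharding, PartitionSpec

def maybe_eager_shard(value, out_sharding, mesh):
  """Apply a sharding constraint when a variable carries an annotation."""
  if not out_sharding:  # no annotation: return the value untouched
    return value
  if mesh is None:  # no global mesh context and no mesh supplied
    raise ValueError(
        f'A valid mesh is required to create a variable with {out_sharding=}.'
    )
  spec = PartitionSpec(*out_sharding)
  return jax.lax.with_sharding_constraint(value, NamedSharding(mesh, spec))
```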

docs_nnx/guides/flax_gspmd.ipynb

Lines changed: 3 additions & 3 deletions
@@ -123,7 +123,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "nnx.Param(jnp.ones(4,4), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)"
+    "nnx.Param(jnp.ones((4, 4)), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)"
    ]
   },
   {
@@ -134,7 +134,7 @@
    "\n",
    "Let's begin by sharding the simplest component possible - a Flax variable.\n",
    "\n",
-   "When you define a Flax variable, you can pass in a metadata field called `sharding_names`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.\n",
+   "When you define a Flax variable, you can pass in a metadata field called `out_sharding` to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refers to how an axis of the array should be sharded.\n",
    "\n",
    "**You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook."
   ]
@@ -191,7 +191,7 @@
    "with jax.set_mesh(auto_mesh):\n",
    "  w = nnx.Param(\n",
    "    rngs.lecun_normal()((4, 8)),\n",
-   "    sharding_names=(None, 'model')\n",
+   "    out_sharding=(None, 'model')\n",
    "  )\n",
    "  print(w.sharding.spec)\n",
    "  jax.debug.visualize_array_sharding(w) # already sharded!"

docs_nnx/guides/flax_gspmd.md

Lines changed: 3 additions & 3 deletions
@@ -64,14 +64,14 @@ with nnx.use_eager_sharding(False):
 You can also enable eager sharding on a per-variable basis by passing `eager_sharding=True` during variable initialization. The mesh can also be passed this way.
 
 ```{code-cell} ipython3
-nnx.Param(jnp.ones(4,4), sharding_names=(None, 'model'), eager_sharding=True, mesh=auto_mesh)
+nnx.Param(jnp.ones((4, 4)), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)
 ```
 
 ## Shard a single-array model
 
 Let's begin by sharding the simplest component possible - a Flax variable.
 
-When you define a Flax variable, you can pass in a metadata field called `sharding_names`, to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refer to how an axis of the array should be sharded.
+When you define a Flax variable, you can pass in a metadata field called `out_sharding` to specify how the underlying JAX array should be sharded. This field should be a tuple of names, each of which refers to how an axis of the array should be sharded.
 
 **You must have an existing device mesh** and create a sharding-annotated `nnx.Variable` within its scope. This allows the result variable to be sharded accordingly on those devices. The device mesh can be your actual accelerator mesh, or a dummy fake CPU mesh like in this notebook.
 
@@ -81,7 +81,7 @@ rngs = nnx.Rngs(0)
 with jax.set_mesh(auto_mesh):
   w = nnx.Param(
     rngs.lecun_normal()((4, 8)),
-    sharding_names=(None, 'model')
+    out_sharding=(None, 'model')
   )
   print(w.sharding.spec)
   jax.debug.visualize_array_sharding(w) # already sharded!

docs_nnx/guides/transforms.ipynb

Lines changed: 10 additions & 10 deletions
@@ -815,30 +815,30 @@
     "output_type": "stream",
     "text": [
      "Inner m.param.shape = (3, 5)\n",
-     "Inner m.param.sharding_names = ('a', None)\n",
+     "Inner m.param.out_sharding = ('a', None)\n",
      "Outter m.param.shape = (3, 4, 5)\n",
-     "Outter m.param.sharding_names = ('a', 'b', None)\n"
+     "Outter m.param.out_sharding = ('a', 'b', None)\n"
     ]
    }
   ],
   "source": [
    "mesh = jax.make_mesh((1, 1), ('a', 'b'))\n",
    "\n",
    "class Weights(nnx.Module):\n",
-   "  def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]):\n",
-   "    self.param = nnx.Param(array, sharding_names=sharding_names)\n",
+   "  def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):\n",
+   "    self.param = nnx.Param(array, out_sharding=out_sharding)\n",
    "\n",
    "@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})\n",
    "def f(m: Weights):\n",
    "  print(f'Inner {m.param.shape = }')\n",
-   "  print(f'Inner {m.param.sharding_names = }')\n",
+   "  print(f'Inner {m.param.out_sharding = }')\n",
    "\n",
    "with jax.set_mesh(mesh):\n",
-   "  m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None))\n",
+   "  m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))\n",
    "  f(m)\n",
    "\n",
    "print(f'Outter {m.param.shape = }')\n",
-   "print(f'Outter {m.param.sharding_names = }')"
+   "print(f'Outter {m.param.out_sharding = }')"
   ]
  },
  {
@@ -862,19 +862,19 @@
     "output_type": "stream",
     "text": [
      "Outter m.param.shape = (3, 4, 5)\n",
-     "Outter m.param.sharding_names = ('a', 'b', None)\n"
+     "Outter m.param.out_sharding = ('a', 'b', None)\n"
     ]
    }
   ],
   "source": [
    "@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})\n",
    "def init_vmap():\n",
-   "  return Weights(jnp.ones((3, 5)), sharding_names=('a', None))\n",
+   "  return Weights(jnp.ones((3, 5)), out_sharding=('a', None))\n",
    "\n",
    "with jax.set_mesh(mesh):\n",
    "  m = init_vmap()\n",
    "print(f'Outter {m.param.shape = }')\n",
-   "print(f'Outter {m.param.sharding_names = }')"
+   "print(f'Outter {m.param.out_sharding = }')"
   ]
  }
 ],

docs_nnx/guides/transforms.md

Lines changed: 7 additions & 7 deletions
@@ -391,20 +391,20 @@ Let's see an example of this in action:
 mesh = jax.make_mesh((1, 1), ('a', 'b'))
 
 class Weights(nnx.Module):
-  def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]):
-    self.param = nnx.Param(array, sharding_names=sharding_names)
+  def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):
+    self.param = nnx.Param(array, out_sharding=out_sharding)
 
 @nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})
 def f(m: Weights):
   print(f'Inner {m.param.shape = }')
-  print(f'Inner {m.param.sharding_names = }')
+  print(f'Inner {m.param.out_sharding = }')
 
 with jax.set_mesh(mesh):
-  m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None))
+  m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))
   f(m)
 
 print(f'Outter {m.param.shape = }')
-print(f'Outter {m.param.sharding_names = }')
+print(f'Outter {m.param.out_sharding = }')
 ```
 
 Here, you added a `sharding` metadata to the `nnx.Param` variables, and used `transform_metadata` to update the `sharding` metadata to reflect the axis changes. Specifically, you can see that the first axis `b` was removed from the `sharding` metadata when inside of `nnx.vmap`, and then added back when outside of `nnx.vmap`.
@@ -414,10 +414,10 @@ You can verify that this also works when `nnx.Module`s are created inside the tr
 ```{code-cell} ipython3
 @nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})
 def init_vmap():
-  return Weights(jnp.ones((3, 5)), sharding_names=('a', None))
+  return Weights(jnp.ones((3, 5)), out_sharding=('a', None))
 
 with jax.set_mesh(mesh):
   m = init_vmap()
 print(f'Outter {m.param.shape = }')
-print(f'Outter {m.param.sharding_names = }')
+print(f'Outter {m.param.out_sharding = }')
 ```

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 3 additions & 3 deletions
@@ -56,15 +56,15 @@ class MLP(nnx.Module):
   def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
     self.w1 = nnx.Param(
       nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
-      sharding_names=mesh_rules('embed', 'mlp'),
+      out_sharding=mesh_rules('embed', 'mlp'),
     )
     self.b1 = nnx.Param(
       jnp.zeros((dmid,)),
-      sharding_names=mesh_rules('mlp'),
+      out_sharding=mesh_rules('mlp'),
     )
     self.w2 = nnx.Param(
       nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
-      sharding_names=mesh_rules('embed', 'mlp'),
+      out_sharding=mesh_rules('embed', 'mlp'),
     )
 
   def __call__(self, x: jax.Array):
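`mesh_rules` is defined earlier in this example file and is not part of the diff; a plausible sketch of such a helper, mapping logical axis names to mesh axis names (the field names and values here are assumptions):

```python
import dataclasses

@dataclasses.dataclass
class MeshRules:
  embed: str | None = None
  mlp: str | None = None

  def __call__(self, *keys: str) -> tuple[str | None, ...]:
    # Resolve each logical key to the mesh axis it is assigned to.
    return tuple(getattr(self, key) for key in keys)

mesh_rules = MeshRules(embed='data', mlp='model')
assert mesh_rules('embed', 'mlp') == ('data', 'model')
```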

flax/core/meta.py

Lines changed: 2 additions & 2 deletions
@@ -297,13 +297,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
   def to_nnx_metadata(self) -> dict[str, Any]:
     """Return a dict of metadata that can translate into an `nnx.Variable`."""
     metadata = dict(vars(self))
-    metadata['sharding_names'] = metadata.pop('names')
+    metadata['out_sharding'] = metadata.pop('names')
     return metadata
 
   @classmethod
   def from_nnx_metadata(cls, metadata: dict[str, Any]):
     """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
-    metadata['names'] = metadata.pop('sharding_names')
+    metadata['names'] = metadata.pop('out_sharding')
     fields = {x.name for x in dataclasses.fields(cls)}
     return cls(**{k: v for k, v in metadata.items() if k in fields})
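A short usage sketch of the remapping above; assuming `nn.Partitioned` boxes a value with a `names` annotation, the round trip looks like this:

```python
import jax.numpy as jnp
import flax.linen as nn

boxed = nn.Partitioned(jnp.ones((4, 8)), names=(None, 'model'))
meta = boxed.to_nnx_metadata()               # 'names' arrives as 'out_sharding'
assert meta['out_sharding'] == (None, 'model')
restored = nn.Partitioned.from_nnx_metadata(meta)  # and maps back to 'names'
assert restored.names == (None, 'model')
```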

flax/core/spmd.py

Lines changed: 7 additions & 8 deletions
@@ -24,13 +24,13 @@
   Sharding,
 )
 
-def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec:
+def get_pspec(sharding, sharding_rules = None) -> PartitionSpec:
   """Given an `nnx.Variable`, return its `PartitionSpec`."""
   if get_logical_axis_rules() or sharding_rules:
     context_rules = get_logical_axis_rules()
     rules = composite_rules(context_rules, sharding_rules)
-    return PartitionSpec(*from_sharding_rules(sharding_names, rules))
-  return PartitionSpec(*sharding_names)
+    return PartitionSpec(*from_sharding_rules(sharding, rules))
+  return PartitionSpec(*sharding)
 
 def _apply_sharding(value, sharding, mesh):
   if mesh.are_all_axes_explicit:
@@ -44,10 +44,9 @@ def _apply_sharding(value, sharding, mesh):
 
 
 def shard_value(
-  value, sharding_names, sharding_rules,
-  mesh: jax.sharding.AbstractMesh | jax.sharding.Mesh | None
+  value, sharding, sharding_rules, mesh: jax.sharding.AbstractMesh | jax.sharding.Mesh | None
 ):
-  if not sharding_names:
+  if not sharding:
     return value
 
   if mesh is None:
@@ -56,9 +55,9 @@ def _apply_sharding(value, sharding, mesh):
   if mesh is None:
     raise ValueError(
       'An auto mesh context or metadata is required if creating a variable'
-      f' with annotation {sharding_names=}. '
+      f' with annotation {sharding=}. '
       'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
-  pspec = get_pspec(sharding_names, sharding_rules)
+  pspec = get_pspec(sharding, sharding_rules)
   return _apply_sharding(value, NamedSharding(mesh, pspec), mesh)
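To make the rules lookup in `get_pspec` concrete, a standalone sketch of the name resolution (`resolve` is an illustrative stand-in for the `from_sharding_rules` helper):

```python
from jax.sharding import PartitionSpec

def resolve(sharding, rules):
  # Each rule pairs a logical axis name with the mesh axis it maps to.
  mapping = dict(rules)
  return PartitionSpec(*(mapping.get(name, name) for name in sharding))

rules = (('embed', 'data'), ('mlp', 'model'))
assert resolve(('embed', 'mlp'), rules) == PartitionSpec('data', 'model')
```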

flax/linen/spmd.py

Lines changed: 2 additions & 2 deletions
@@ -303,15 +303,15 @@ def to_nnx_metadata(self) -> dict[str, Any]:
     """Return a dict of metadata that can translate into an `nnx.Variable`."""
     metadata = vars(self)
     if 'names' in metadata:
-      metadata['sharding_names'] = metadata.pop('names')
+      metadata['out_sharding'] = metadata.pop('names')
     if 'rules' in metadata:
       metadata['sharding_rules'] = metadata.pop('rules')
     return metadata
 
   @classmethod
   def from_nnx_metadata(cls, metadata: dict[str, Any]):
     """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
-    metadata['names'] = metadata.pop('sharding_names')
+    metadata['names'] = metadata.pop('out_sharding')
     metadata['rules'] = metadata.pop('sharding_rules')
     fields = {x.name for x in dataclasses.fields(cls)}
     return cls(**{k: v for k, v in metadata.items() if k in fields})
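The analogous round trip for `nn.LogicallyPartitioned`, which additionally remaps `rules` to `sharding_rules` (constructor arguments here are assumptions):

```python
import jax.numpy as jnp
import flax.linen as nn

boxed = nn.LogicallyPartitioned(
    jnp.ones((4, 8)),
    names=('embed', 'mlp'),
    rules=(('embed', None), ('mlp', 'model')),
)
meta = boxed.to_nnx_metadata()
assert meta['out_sharding'] == ('embed', 'mlp')
assert meta['sharding_rules'] == (('embed', None), ('mlp', 'model'))
```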
