|
815 | 815 | "output_type": "stream", |
816 | 816 | "text": [ |
817 | 817 | "Inner m.param.shape = (3, 5)\n", |
818 | | - "Inner m.param.sharding_names = ('a', None)\n", |
| 818 | + "Inner m.param.out_sharding = ('a', None)\n", |
819 | 819 | "Outter m.param.shape = (3, 4, 5)\n", |
820 | | - "Outter m.param.sharding_names = ('a', 'b', None)\n" |
| 820 | + "Outter m.param.out_sharding = ('a', 'b', None)\n" |
821 | 821 | ] |
822 | 822 | } |
823 | 823 | ], |
824 | 824 | "source": [ |
825 | 825 | "mesh = jax.make_mesh((1, 1), ('a', 'b'))\n", |
826 | 826 | "\n", |
827 | 827 | "class Weights(nnx.Module):\n", |
828 | | - " def __init__(self, array: jax.Array, sharding_names: tuple[str | None, ...]):\n", |
829 | | - " self.param = nnx.Param(array, sharding_names=sharding_names)\n", |
| 828 | + " def __init__(self, array: jax.Array, out_sharding: tuple[str | None, ...]):\n", |
| 829 | + " self.param = nnx.Param(array, out_sharding=out_sharding)\n", |
830 | 830 | "\n", |
831 | 831 | "@nnx.vmap(in_axes=1, transform_metadata={nnx.PARTITION_NAME: 'b'})\n", |
832 | 832 | "def f(m: Weights):\n", |
833 | 833 | " print(f'Inner {m.param.shape = }')\n", |
834 | | - " print(f'Inner {m.param.sharding_names = }')\n", |
| 834 | + " print(f'Inner {m.param.out_sharding = }')\n", |
835 | 835 | "\n", |
836 | 836 | "with jax.set_mesh(mesh):\n", |
837 | | - " m = Weights(jnp.ones((3, 4, 5)), sharding_names=('a', 'b', None))\n", |
| 837 | + " m = Weights(jnp.ones((3, 4, 5)), out_sharding=('a', 'b', None))\n", |
838 | 838 | " f(m)\n", |
839 | 839 | "\n", |
840 | 840 | "print(f'Outter {m.param.shape = }')\n", |
841 | | - "print(f'Outter {m.param.sharding_names = }')" |
| 841 | + "print(f'Outter {m.param.out_sharding = }')" |
842 | 842 | ] |
843 | 843 | }, |
844 | 844 | { |
|
862 | 862 | "output_type": "stream", |
863 | 863 | "text": [ |
864 | 864 | "Outter m.param.shape = (3, 4, 5)\n", |
865 | | - "Outter m.param.sharding_names = ('a', 'b', None)\n" |
| 865 | + "Outter m.param.out_sharding = ('a', 'b', None)\n" |
866 | 866 | ] |
867 | 867 | } |
868 | 868 | ], |
869 | 869 | "source": [ |
870 | 870 | "@nnx.vmap(out_axes=1, axis_size=4, transform_metadata={nnx.PARTITION_NAME: 'b'})\n", |
871 | 871 | "def init_vmap():\n", |
872 | | - " return Weights(jnp.ones((3, 5)), sharding_names=('a', None))\n", |
| 872 | + " return Weights(jnp.ones((3, 5)), out_sharding=('a', None))\n", |
873 | 873 | "\n", |
874 | 874 | "with jax.set_mesh(mesh):\n", |
875 | 875 | " m = init_vmap()\n", |
876 | 876 | "print(f'Outter {m.param.shape = }')\n", |
877 | | - "print(f'Outter {m.param.sharding_names = }')" |
| 877 | + "print(f'Outter {m.param.out_sharding = }')" |
878 | 878 | ] |
879 | 879 | } |
880 | 880 | ], |
|
0 commit comments