Hey folks, me again.
I've recently hit the following problem when initializing a model that uses multiple meshes. In short, the output sharding from a jitted init_fn comes back essentially arbitrary instead of matching the specified shardings. It also seems that an output tensor's mesh depends on the ordering of the flattened tree. Check out this repro script:
```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx

mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))

mesh2 = jax.make_mesh((2, 2, 2), ("x", "y", "z"))
rules2 = (("X", "x"), ("Y", "y"), ("Z", "z"))

mesh3 = jax.make_mesh((8,), ("c",))
rules3 = (("C", "c"),)

mesh_data = jax.make_mesh((4, 2), ("data", "context"))


class Model(nnx.Module):
    def __init__(self):
        self.small_linear1 = nnx.Param(
            jnp.ones((16, 16)),
            sharding=("A", "B"),
            mesh=mesh1,
            sharding_rules=rules1,
        )
        self.small_linear2 = nnx.Param(
            jnp.ones((16, 16, 16)),
            sharding=("X", "Y", "Z"),
            mesh=mesh2,
            sharding_rules=rules2,
        )
        self.small_linear3 = nnx.Param(
            jnp.ones((16, 16)),
            sharding=("C",),
            mesh=mesh3,
            sharding_rules=rules3,
        )


def init_model_no_jit():
    return Model()


@nnx.jit
def init_model_nnx_jit():
    model = init_model_no_jit()
    return model


with mesh_data:
    model_nnx_jit = init_model_nnx_jit()
    model_no_jit = init_model_no_jit()


def _print_t_sharding(key, t):
    print(f"Key: {'.'.join(map(str, key))}, shape: {t.shape}, sharding: {t.sharding}")


print("\nSharding without JIT:")
jax.tree.map_with_path(_print_t_sharding, model_no_jit)

print("Sharding with NNX.JIT:")
jax.tree.map_with_path(_print_t_sharding, model_nnx_jit)
```
Output:

```
Sharding without JIT:
Key: .small_linear1..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)
Key: .small_linear2..value, shape: (16, 16, 16), sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('x', 'y', 'z'), memory_kind=device)
Key: .small_linear3..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('c': 8, axis_types=(Auto,)), spec=PartitionSpec('c',), memory_kind=device)

Sharding with NNX.JIT:
Key: .small_linear1..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)
Key: .small_linear2..value, shape: (16, 16, 16), sharding: GSPMDSharding({devices=[2,2,2]<=[8]}, memory_kind=device)
Key: .small_linear3..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(('a', 'b'),), memory_kind=device)
```
The no-JIT version, on the other hand, works correctly (but, as one may imagine, is not suitable for large-scale initialization).
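For comparison, here's a sketch of the single-mesh init pattern from the Flax scale-up docs, which constrains the state inside the jitted function via `nnx.get_partition_spec` + `jax.lax.with_sharding_constraint` (the `init_sharded_model` name is just for illustration). I don't think it covers this case, since bare `PartitionSpec`s resolve against the single enclosing mesh context rather than each param's `mesh=` metadata, but it shows the behavior I'd expect `nnx.jit` to give for free:

```python
# Sketch of the documented single-mesh pattern: constrain the state inside jit.
# Assumption: the PartitionSpecs returned by nnx.get_partition_spec resolve
# against the one enclosing mesh context, so this likely does NOT handle the
# three separate meshes used in the repro above.
@nnx.jit
def init_sharded_model():
    model = Model()
    state = nnx.state(model)                 # pull out the parameter pytree
    pspecs = nnx.get_partition_spec(state)   # PartitionSpecs from `sharding` metadata
    state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, state)                 # write the constrained values back
    return model


with mesh_data:  # specs resolve against this mesh, not mesh1/mesh2/mesh3
    model = init_sharded_model()
```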