Out sharding for modules initialized with JIT is incorrect #5127

@qGentry

Hey folks, me again.

I've recently run into the following problem when initializing a model whose parameters live on multiple meshes. The output shardings returned from a jitted init_fn do not stick to the specified ones: instead of the requested mesh and spec, some parameters come back with a GSPMDSharding or with another parameter's mesh. It also seems that an output tensor's mesh depends on the ordering of the flattened tree. Check out this repro script:

import jax
import jax.numpy as jnp
import flax.nnx as nnx


mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))
mesh2 = jax.make_mesh((2, 2, 2), ("x", "y", "z"))
rules2 = (("X", "x"), ("Y", "y"), ("Z", "z"))
mesh3 = jax.make_mesh((8,), ("c",))
rules3 = (("C", "c"),)

mesh_data = jax.make_mesh((4, 2), ("data", "context"))


class Model(nnx.Module):
    def __init__(self):
        self.small_linear1 = nnx.Param(
            jnp.ones((16, 16)), 
            sharding=("A", "B"), 
            mesh=mesh1,
            sharding_rules=rules1,
        )
        self.small_linear2 = nnx.Param(
            jnp.ones((16, 16, 16)), 
            sharding=("X", "Y", "Z"), 
            mesh=mesh2,
            sharding_rules=rules2,
        )
        self.small_linear3 = nnx.Param(
            jnp.ones((16, 16)),
            sharding=("C",), 
            mesh=mesh3,
            sharding_rules=rules3,
        )


def init_model_no_jit():
    return Model()


@nnx.jit
def init_model_nnx_jit():
    model = init_model_no_jit()
    return model


with mesh_data:
    model_nnx_jit = init_model_nnx_jit()
    model_no_jit = init_model_no_jit()

    # Print each leaf's path, shape, and sharding.
    def _print_t_sharding(key, t):
        print(f"Key: {'.'.join(map(str, key))}, shape: {t.shape}, sharding: {t.sharding}")

    print("\nSharding without JIT:")
    jax.tree.map_with_path(_print_t_sharding, model_no_jit)

    print("Sharding with NNX.JIT:")
    jax.tree.map_with_path(_print_t_sharding, model_nnx_jit)

Output:

Sharding without JIT:
Key: .small_linear1..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)
Key: .small_linear2..value, shape: (16, 16, 16), sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('x', 'y', 'z'), memory_kind=device)
Key: .small_linear3..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('c': 8, axis_types=(Auto,)), spec=PartitionSpec('c',), memory_kind=device)
Sharding with NNX.JIT:
Key: .small_linear1..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)
Key: .small_linear2..value, shape: (16, 16, 16), sharding: GSPMDSharding({devices=[2,2,2]<=[8]}, memory_kind=device)
Key: .small_linear3..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(('a', 'b'),), memory_kind=device)

The no-JIT version, on the other hand, works correctly, but as one may imagine it is not suitable for large-scale init.
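To compare the two initializations mechanically instead of by eye, something like the helper below could be used. It is only a sketch: `sharding_mismatches` is a hypothetical name, not part of JAX or Flax, and it assumes both trees flatten to `jax.Array` leaves in the same order. It relies on `Sharding.is_equivalent_to`, which compares how an array is laid out across devices rather than how the sharding prints, so it can help tell a purely cosmetic difference (for example a `GSPMDSharding` describing the same placement) from a real one. Whether any of the differences in the output above are cosmetic is not established here.

```python
import jax
import jax.numpy as jnp


def sharding_mismatches(expected_tree, actual_tree):
    """Return (path, expected_sharding, actual_sharding) for every leaf whose
    sharding in `actual_tree` is not equivalent to the one in `expected_tree`."""
    expected, _ = jax.tree_util.tree_flatten_with_path(expected_tree)
    actual = jax.tree_util.tree_leaves(actual_tree)
    if len(expected) != len(actual):
        raise ValueError("trees have different numbers of leaves")

    mismatches = []
    for (path, exp), act in zip(expected, actual):
        # Equivalence is about device layout, not the printed representation.
        if not exp.sharding.is_equivalent_to(act.sharding, exp.ndim):
            mismatches.append((jax.tree_util.keystr(path), exp.sharding, act.sharding))
    return mismatches


# Tiny self-check on plain arrays: a tree always matches itself.
tree = {"w": jnp.ones((4, 4)), "b": jnp.ones((4,))}
assert sharding_mismatches(tree, tree) == []
```

With the repro script this would be called as `sharding_mismatches(model_no_jit, model_nnx_jit)`, assuming the models flatten as pytrees the same way they do for `jax.tree.map_with_path` above.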
