Skip to content

JaxPP fails on basic NNX example with scan #7

@qGentry

Description

@qGentry

Hi guys, I'm currently trying to build a proof of concept using JaxPP & Flax.NNX API.
I'm targeting the use case where the model initialization is completely independent of JaxPP, uses vmap to create modules, and, for each individual PP stage, I'm using scan to iterate over the modules of the given stage.
I've implemented the following snippet but for some reason it doesn't work, failing during the tracing of treduce with the following error:

  File "pp_poc.py", line 304, in <module>
    main()
  File "pp_poc.py", line 287, in main
    train_step_jaxpp_fn_compiled = train_step_jaxpp_fn.compile(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxpp/core.py", line 3139, in compile
    self.trace_and_place(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/jaxpp/core.py", line 3156, in trace_and_place
    p, _ = _infer_params(self.fun, self.pjit_info, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 627, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 654, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 551, in _infer_params_impl
    jaxpr, consts, out_avals = _create_pjit_jaxpr(
                               ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 496, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 1183, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/profiler.py", line 359, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 2409, in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 212, in call_wrapped
    return self.f_transformed(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/api_util.py", line 73, in flatten_fun
    ans = f(*py_args, **py_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 421, in _get_result_paths_thunk
    ans = _fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "pp_poc.py", line 196, in train_step
    loss_acc, grads_acc = jaxpp.api.treduce(
                          ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxpp/training.py", line 228, in treduce
    return treduce_i(
           ^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxpp/training.py", line 312, in treduce_i
    loop_output = pscan_wrapped(
                  ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jaxpp/training.py", line 99, in pscan_wrapped
    new_out.append(ad.add_jaxvals(out[idx], out[idx + 1]))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/ad_util.py", line 38, in add_jaxvals
    return add_jaxvals_p.bind(x, y)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 632, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 648, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 660, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 2117, in process_primitive
    return self.default_process_primitive(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 2135, in default_process_primitive
    out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/util.py", line 460, in wrapper
    return cached_call(_multi_weakref_placeholder,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/util.py", line 444, in cache_miss
    return call(*orig_args, **orig_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 1946, in _cached_abstract_eval
    return primitive.abstract_eval(*aval_qdds, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 702, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/ad_util.py", line 50, in add_abstract
    assert core.typematch(x, y), (x, y)
           ^^^^^^^^^^^^^^^^^^^^
AssertionError: (ShapedArray(float32[8,16]), ShapedArray(float32[8,16,16]))

How to run:

CUDA_VISIBLE_DEVICES="0,1,2,3" JAX_RANK=0 python3 pp_poc.py
CUDA_VISIBLE_DEVICES="4,5,6,7" JAX_RANK=1 python3 pp_poc.py

Snippet:

import os
from collections.abc import Callable
from typing import TypeVar

import jax
import jax.numpy as jnp
import jaxpp
import jaxpp.schedules
import optax
from flax import nnx
from jaxpp.pipelining import pipeline_enter_stage


A = TypeVar("A")


def eval_shape_with_sharding(f: Callable[..., A], *args, **kwargs) -> A:
    """Abstractly evaluate ``f`` and re-attach sharding info to its variables.

    Works around flax's ``nnx.eval_shape`` not propagating sharding
    information. Issue to track: https://github.com/google/flax/issues/5110
    """
    module = nnx.eval_shape(f, *args, **kwargs)
    state = nnx.state(module)
    pspec = nnx.spmd.get_partition_spec(state)

    def attach_sharding(var: nnx.Variable, var_pspec: nnx.Variable) -> nnx.Variable:
        # Produce a copy of `var` whose value carries a NamedSharding built
        # from the variable's mesh metadata and its partition spec.
        current = var.get_value()
        if not isinstance(current, jax.ShapeDtypeStruct | jax.Array):
            # var.value may be MaskedNode when training subset of parameters
            return var
        mesh = var.get_metadata().get("mesh", None)
        result = var.copy()
        if mesh is not None:
            sharded_struct = jax.ShapeDtypeStruct(
                shape=current.shape,
                dtype=current.dtype,
                sharding=jax.sharding.NamedSharding(
                    mesh=mesh,
                    spec=var_pspec.get_value(),
                ),
            )
            result.set_value(sharded_struct)
        return result

    state_with_sharding = jax.tree.map(
        attach_sharding,
        state,
        pspec,
        is_leaf=lambda leaf: isinstance(leaf, nnx.Variable),
    )
    nnx.update(module, state_with_sharding)
    return module


class Block(nnx.Module):
    """Two-layer MLP: Linear up to ``out_features``, relu, Linear back down."""

    def __init__(self, in_features, out_features, rngs, mesh=None):
        # Up-projection; kernel sharded over ("fsdp", "tensor").
        self.dense1 = nnx.Linear(
            in_features=in_features,
            out_features=out_features,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.lecun_normal(),
                ("fsdp", "tensor"),
                mesh=mesh,
            ),
            rngs=rngs,
        )
        # Down-projection back to the input width; kernel sharded
        # over ("tensor", "fsdp").
        self.dense2 = nnx.Linear(
            in_features=out_features,
            out_features=in_features,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.lecun_normal(),
                ("tensor", "fsdp"),
                mesh=mesh,
            ),
            rngs=rngs,
        )
        self.relu = nnx.relu

    def __call__(self, x):
        hidden = self.relu(self.dense1(x))
        return self.dense2(hidden)


class Model(nnx.Module):
    """Stack of ``n_layers`` vmapped Blocks followed by a pooling Linear head."""

    def __init__(
        self,
        in_features,
        out_features,
        n_layers,
        rngs: nnx.Rngs,
        mesh=None,
    ):
        self.n_layers = n_layers
        self.in_features = in_features
        self.out_features = out_features
        self.mesh = mesh

        # Build all layers at once: split the RNG n_layers ways and vmap the
        # Block constructor so parameters are stacked along a new leading axis
        # whose partition-metadata name is "layers".
        @nnx.split_rngs(splits=n_layers)
        @nnx.vmap(
            in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers"}
        )
        def get_blocks(rngs):
            return self.get_blocks(rngs)

        self.blocks = get_blocks(rngs)

        # Final projection to a single output feature; kernel is unsharded
        # (spec (None, None)).
        self.pooling = nnx.Linear(
            in_features=out_features,
            out_features=1,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.lecun_normal(),
                (None, None),
                mesh=mesh,
            ),
            rngs=rngs,
        )

    def get_blocks(self, rngs):
        """Construct a single Block; vmapped in ``__init__`` to build the stack."""
        return Block(
            in_features=self.in_features,
            out_features=self.out_features,
            rngs=rngs,
            mesh=self.mesh,
        )

    def call_blocks(self, x, blocks):
        """Apply a stack of blocks sequentially with ``nnx.scan`` over axis 0."""

        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        def forward(x, block):
            x = block(x)
            return x

        return forward(x, blocks)

    def call_blocks_range(self, x, start, end):
        """Run only the blocks in ``[start, end)`` by slicing the stacked params."""
        selected_blocks = jax.tree.map(lambda t: t[start:end], self.blocks)
        return self.call_blocks(x, selected_blocks)

    def __call__(self, x):
        """Full forward pass: all layers, then the pooling head."""
        x = self.call_blocks_range(x, 0, self.n_layers)
        out = self.pooling(x)
        return out

    def call_stages(self, x, n_stages):
        """Forward pass split into ``n_stages`` equal pipeline-parallel stages.

        Each stage runs a contiguous slice of the layers; stage boundaries
        are marked for JaxPP via ``pipeline_enter_stage``. Requires
        ``n_layers`` to be divisible by ``n_stages``.
        """
        assert self.n_layers % n_stages == 0
        blocks_per_stage = self.n_layers // n_stages
        for i in range(n_stages):
            x = self.call_blocks_range(
                x, start=i * blocks_per_stage, end=(i + 1) * blocks_per_stage
            )
            if i != n_stages - 1:
                # keep pooling as part of final stage, don't enter the new one
                x = pipeline_enter_stage(x)
        out = self.pooling(x)
        return out


def get_model(mesh):
    """Build the reference model (16 -> 16 features, 8 layers, fixed seed)."""
    return Model(
        in_features=16,
        out_features=16,
        n_layers=8,
        rngs=nnx.Rngs(params=0),
        mesh=mesh,
    )


def get_sharding(obj: nnx.State | nnx.Pytree):
    """Return a pytree mapping every leaf of ``obj`` to its ``.sharding``."""

    def leaf_sharding(leaf):
        return leaf.sharding

    return jax.tree.map(leaf_sharding, obj)


def get_jaxpp_train_step(
    model,
    optimizer,
    n_minibatches: int,
    mpmd_mesh: jaxpp.api.MpmdMesh,
    schedule: jaxpp.schedules.BaseSchedule,
):
    """Build a JaxPP-compiled train step with gradient accumulation.

    The step splits each batch into ``n_minibatches`` microbatches, runs the
    staged forward/backward via ``jaxpp.api.treduce`` summing losses and
    gradients, then applies one optimizer update.
    """
    # Capture the graph definition once; each microbatch re-merges it with
    # the current params inside loss_fn.
    model_gdef = nnx.graphdef(model)

    def train_step(model, optimizer, x, y):
        def _loop_body(minibatch):
            x_i, y_i = minibatch

            def loss_fn(params: nnx.State):
                model: Model = nnx.merge(model_gdef, params)
                y_pred = model.call_stages(x_i, n_stages=mpmd_mesh.mpmd_dim)
                loss = ((y_i - y_pred) ** 2).mean()
                # Pre-divide so that summing over microbatches yields the mean.
                return loss / n_minibatches

            grad_fn = jax.value_and_grad(loss_fn)
            loss, grads = grad_fn(nnx.state(model))
            return loss, grads

        # Accumulate (loss, grads) over microbatches with elementwise Add,
        # driven by the pipeline schedule.
        loss_acc, grads_acc = jaxpp.api.treduce(
            _loop_body,
            (
                x.reshape(n_minibatches, -1, *x.shape[1:]),
                y.reshape(n_minibatches, -1, *y.shape[1:]),
            ),
            schedule=schedule,
            operation=jaxpp.api.Add,
        )
        optimizer.update(model, grads_acc)
        return loss_acc, model, optimizer

    # Inputs/outputs are replicated across the lowering mesh; model and
    # optimizer keep their per-variable shardings, retargeted to that mesh.
    replicated_mpmd_sharding = jax.sharding.NamedSharding(
        mesh=mpmd_mesh.lowering_mesh(), spec=jax.P()
    )
    model_mpmd_sharding = jax.tree.map(
        lambda s: s.update(mesh=mpmd_mesh.lowering_mesh()),
        get_sharding(model),
    )
    opt_mpmd_sharding = jax.tree.map(
        lambda s: s.update(mesh=mpmd_mesh.lowering_mesh()),
        get_sharding(optimizer),
    )

    train_step_fn = jaxpp.api.mpmd_jit_with_loop(
        train_step,
        mpmd_mesh=mpmd_mesh,
        in_shardings=(
            model_mpmd_sharding,
            opt_mpmd_sharding,
            replicated_mpmd_sharding,
            replicated_mpmd_sharding,
        ),
        out_shardings=(
            replicated_mpmd_sharding,
            model_mpmd_sharding,
            opt_mpmd_sharding,
        ),
    )
    return train_step_fn


def main():
    """Reproduce the issue: init the model, compile the JaxPP step, train."""
    # Two-process setup; JAX_RANK selects which half of the GPUs this
    # process drives (see the CUDA_VISIBLE_DEVICES invocations above).
    jax.distributed.initialize(
        coordinator_address="0.0.0.0:9444",
        num_processes=2,
        process_id=int(os.environ["JAX_RANK"]),
    )

    n_stages = 2

    # "layers" is both the vmap stacking axis name and the MPMD (pipeline)
    # axis of the mesh; the remaining devices form the tensor axis.
    spmd_mesh = jax.make_mesh(
        (n_stages, 1, 8 // n_stages), ("layers", "fsdp", "tensor")
    )

    BS = 16

    # Abstract init (shapes + shardings only) to obtain the graphdef and the
    # per-variable output shardings for the real jitted init below.
    model_gdef, model_abs_state = nnx.split(
        eval_shape_with_sharding(lambda: get_model(spmd_mesh))
    )
    mpmd_mesh = jaxpp.api.MpmdMesh(spmd_mesh, "layers")

    with jax.set_mesh(spmd_mesh):

        @jax.jit(out_shardings=get_sharding(model_abs_state))
        def get_model_state():
            # Same hyperparameters and seed as get_model / eval_shape above,
            # so the state matches model_gdef.
            model = Model(
                in_features=16,
                out_features=16,
                n_layers=8,
                rngs=nnx.Rngs(params=0),
                mesh=spmd_mesh,
            )
            return nnx.state(model)

        model_state = get_model_state()
        model = nnx.merge(model_gdef, model_state)
        optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

    train_step_jaxpp_fn = get_jaxpp_train_step(
        model,
        optimizer,
        n_minibatches=4,
        mpmd_mesh=mpmd_mesh,
        schedule=jaxpp.schedules.Std1F1B(num_stages=mpmd_mesh.mpmd_dim),
    )

    # traced = train_step_jaxpp_fn.trace_and_place(
    #     model, optimizer, jnp.ones((BS, 512, 16)), jnp.ones((BS, 512, 1))
    # )

    # Compilation is where the reported AssertionError is raised
    # (during tracing of treduce).
    train_step_jaxpp_fn_compiled = train_step_jaxpp_fn.compile(
        model,
        optimizer,
        jnp.ones((BS, 512, 16)),
        jnp.ones((BS, 512, 1)),
    )

    for i in range(10):
        loss, model, optimizer = train_step_jaxpp_fn_compiled(
            model,
            optimizer,
            jax.random.normal(jax.random.PRNGKey(i), (BS, 512, 16)),
            jnp.zeros((BS, 512, 1)),
        )
        print(loss)


# Guard the entry point so importing this module (e.g. from tooling or a
# debugger) does not trigger distributed initialization and training.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions