-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Description
Hi guys, I'm currently trying to build a proof of concept using JaxPP & Flax.NNX API.
I'm targeting the use-case, where the model initialization is completely independent from JaxPP, uses vmap to create modules, and for each individual PP stage, I'm using scan to iterate over the modules of given stage.
I've implemented the following snippet but for some reason it doesn't work, failing during the tracing of treduce with the following error:
File "pp_poc.py", line 304, in <module>
main()
File "pp_poc.py", line 287, in main
train_step_jaxpp_fn_compiled = train_step_jaxpp_fn.compile(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jaxpp/core.py", line 3139, in compile
self.trace_and_place(*args, **kwargs)
File "/usr/local/lib/python3.11/dist-packages/jaxpp/core.py", line 3156, in trace_and_place
p, _ = _infer_params(self.fun, self.pjit_info, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 627, in _infer_params
return _infer_params_internal(fun, ji, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 654, in _infer_params_internal
p, args_flat = _infer_params_impl(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 551, in _infer_params_impl
jaxpr, consts, out_avals = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 496, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/pjit.py", line 1183, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 2409, in trace_to_jaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 212, in call_wrapped
return self.f_transformed(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/api_util.py", line 73, in flatten_fun
ans = f(*py_args, **py_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/linear_util.py", line 421, in _get_result_paths_thunk
ans = _fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "pp_poc.py", line 196, in train_step
loss_acc, grads_acc = jaxpp.api.treduce(
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jaxpp/training.py", line 228, in treduce
return treduce_i(
^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jaxpp/training.py", line 312, in treduce_i
loop_output = pscan_wrapped(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jaxpp/training.py", line 99, in pscan_wrapped
new_out.append(ad.add_jaxvals(out[idx], out[idx + 1]))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/ad_util.py", line 38, in add_jaxvals
return add_jaxvals_p.bind(x, y)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 632, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 648, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 660, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 2117, in process_primitive
return self.default_process_primitive(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 2135, in default_process_primitive
out_avals, effs = _cached_abstract_eval(primitive, *aval_qdds, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/util.py", line 460, in wrapper
return cached_call(_multi_weakref_placeholder,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/util.py", line 444, in cache_miss
return call(*orig_args, **orig_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/interpreters/partial_eval.py", line 1946, in _cached_abstract_eval
return primitive.abstract_eval(*aval_qdds, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/core.py", line 702, in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jax/_src/ad_util.py", line 50, in add_abstract
assert core.typematch(x, y), (x, y)
^^^^^^^^^^^^^^^^^^^^
AssertionError: (ShapedArray(float32[8,16]), ShapedArray(float32[8,16,16]))

How to run:
CUDA_VISIBLE_DEVICES="0,1,2,3" JAX_RANK=0 python3 pp_poc.py
CUDA_VISIBLE_DEVICES="4,5,6,7" JAX_RANK=1 python3 pp_poc.py

Snippet:
import os
from collections.abc import Callable
from typing import TypeVar
import jax
import jax.numpy as jnp
import jaxpp
import jaxpp.schedules
import optax
from flax import nnx
from jaxpp.pipelining import pipeline_enter_stage
# Generic return-type variable so eval_shape_with_sharding can declare that
# it returns the same type its callable argument returns.
A = TypeVar("A")
def eval_shape_with_sharding(f: Callable[..., A], *args, **kwargs) -> A:
    """Abstractly evaluate ``f`` and attach sharding info to each variable.

    Runs ``nnx.eval_shape`` (no real arrays are allocated), then rewrites
    every variable's abstract value into a ``jax.ShapeDtypeStruct`` that
    carries a ``NamedSharding`` built from the variable's partition spec and
    its ``mesh`` metadata.

    Args:
        f: Constructor returning an ``nnx.Module``; evaluated abstractly.
        *args, **kwargs: Forwarded to ``f``.

    Returns:
        The abstract module whose variable values are sharding-annotated
        ``ShapeDtypeStruct``s (where a mesh was available).
    """
    # Currently flax's nnx.eval_shape does not propagate sharding information.
    # Issue to track: https://github.com/google/flax/issues/5110
    module = nnx.eval_shape(f, *args, **kwargs)
    state = nnx.state(module)
    pspec = nnx.spmd.get_partition_spec(state)
    def wrap_with_sharding(var: nnx.Variable, var_pspec: nnx.Variable) -> nnx.Variable:
        value = var.get_value()
        if not isinstance(value, jax.ShapeDtypeStruct | jax.Array):
            # var.value may be MaskedNode when training subset of parameters
            return var
        # Copy so the variable pulled out of `state` is not mutated in place.
        new_var = var.copy()
        var_mesh = var.get_metadata().get("mesh", None)
        # Only variables that carry an explicit mesh in their metadata (e.g.
        # created via nnx.with_partitioning(..., mesh=...)) get a
        # NamedSharding; others are returned with their value unchanged.
        if var_mesh is not None:
            new_var.set_value(
                jax.ShapeDtypeStruct(
                    shape=value.shape,
                    dtype=value.dtype,
                    sharding=jax.sharding.NamedSharding(
                        mesh=var_mesh,
                        spec=var_pspec.get_value(),
                    ),
                )
            )
        return new_var
    # Zip state with its partition specs leaf-by-leaf; whole nnx.Variable
    # objects are treated as leaves so their metadata travels with the value.
    state_with_sharding = jax.tree.map(
        lambda t, spec: wrap_with_sharding(t, spec),
        state,
        pspec,
        is_leaf=lambda x: isinstance(x, nnx.Variable),
    )
    nnx.update(module, state_with_sharding)
    return module
class Block(nnx.Module):
    """Two-layer MLP: ``in_features -> out_features -> in_features`` with relu."""

    def __init__(self, in_features, out_features, rngs, mesh=None):
        # Up-projection kernel is partitioned ("fsdp", "tensor"); the
        # down-projection uses the transposed layout ("tensor", "fsdp").
        up_init = nnx.with_partitioning(
            nnx.initializers.lecun_normal(),
            ("fsdp", "tensor"),
            mesh=mesh,
        )
        down_init = nnx.with_partitioning(
            nnx.initializers.lecun_normal(),
            ("tensor", "fsdp"),
            mesh=mesh,
        )
        self.dense1 = nnx.Linear(
            in_features=in_features,
            out_features=out_features,
            kernel_init=up_init,
            rngs=rngs,
        )
        self.dense2 = nnx.Linear(
            in_features=out_features,
            out_features=in_features,
            kernel_init=down_init,
            rngs=rngs,
        )
        self.relu = nnx.relu

    def __call__(self, x):
        # dense1 -> relu -> dense2, exactly as in the expanded form.
        return self.dense2(self.relu(self.dense1(x)))
class Model(nnx.Module):
    """Stack of ``n_layers`` vmapped ``Block``s followed by a pooling head."""

    def __init__(
        self,
        in_features,
        out_features,
        n_layers,
        rngs: nnx.Rngs,
        mesh=None,
    ):
        self.n_layers = n_layers
        self.in_features = in_features
        self.out_features = out_features
        self.mesh = mesh

        # Build all layers at once: split the rng stream into one key per
        # layer and vmap the single-block constructor so parameters are
        # stacked along a leading axis named "layers" (via transform_metadata).
        @nnx.split_rngs(splits=n_layers)
        @nnx.vmap(
            in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers"}
        )
        def get_blocks(rngs):
            # NOTE(review): this local function shadows the Model.get_blocks
            # method name; the call below resolves to the method on self.
            return self.get_blocks(rngs)

        self.blocks = get_blocks(rngs)
        # Final projection to one output feature; replicated (None, None).
        self.pooling = nnx.Linear(
            in_features=out_features,
            out_features=1,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.lecun_normal(),
                (None, None),
                mesh=mesh,
            ),
            rngs=rngs,
        )

    def get_blocks(self, rngs):
        # Constructor for one Block; vmapped in __init__ to create the stack.
        return Block(
            in_features=self.in_features,
            out_features=self.out_features,
            rngs=rngs,
            mesh=self.mesh,
        )

    def call_blocks(self, x, blocks):
        # Run the stacked blocks sequentially with scan, carrying activations.
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        def forward(x, block):
            x = block(x)
            return x

        return forward(x, blocks)

    def call_blocks_range(self, x, start, end):
        # Slice the stacked parameters along the leading "layers" axis to
        # select blocks [start, end) before scanning over them.
        selected_blocks = jax.tree.map(lambda t: t[start:end], self.blocks)
        return self.call_blocks(x, selected_blocks)

    def __call__(self, x):
        # Single-stage forward pass: all layers, then pooling.
        x = self.call_blocks_range(x, 0, self.n_layers)
        out = self.pooling(x)
        return out

    def call_stages(self, x, n_stages):
        """Forward pass partitioned into ``n_stages`` pipeline stages.

        Each stage runs an equal share of the layers; ``pipeline_enter_stage``
        marks the boundary between consecutive JaxPP stages.
        """
        assert self.n_layers % n_stages == 0
        blocks_per_stage = self.n_layers // n_stages
        for i in range(n_stages):
            x = self.call_blocks_range(
                x, start=i * blocks_per_stage, end=(i + 1) * blocks_per_stage
            )
            if i != n_stages - 1:
                # keep pooling as part of final stage, don't enter the new one
                x = pipeline_enter_stage(x)
        out = self.pooling(x)
        return out
def get_model(mesh, in_features=16, out_features=16, n_layers=8, seed=0):
    """Build the demo ``Model``.

    Generalized from hard-coded constants; defaults preserve the original
    behavior, so existing ``get_model(mesh)`` call sites are unchanged.

    Args:
        mesh: Device mesh forwarded to the partitioned initializers.
        in_features: Input (and Block output) feature size.
        out_features: Hidden feature size of each Block.
        n_layers: Number of stacked Blocks.
        seed: PRNG seed for the ``params`` rng stream.

    Returns:
        A freshly initialized ``Model``.
    """
    model = Model(
        in_features=in_features,
        out_features=out_features,
        n_layers=n_layers,
        rngs=nnx.Rngs(params=seed),
        mesh=mesh,
    )
    return model
def get_sharding(obj: nnx.State | nnx.Pytree):
    """Return a pytree mirroring ``obj`` with every leaf replaced by its ``.sharding``."""
    def leaf_sharding(leaf):
        return leaf.sharding
    return jax.tree.map(leaf_sharding, obj)
def get_jaxpp_train_step(
    model,
    optimizer,
    n_minibatches: int,
    mpmd_mesh: jaxpp.api.MpmdMesh,
    schedule: jaxpp.schedules.BaseSchedule,
):
    """Build the JaxPP pipeline-parallel train-step function.

    Args:
        model: Initialized ``Model`` sharded on the SPMD mesh.
        optimizer: ``nnx.Optimizer`` over the model's parameters.
        n_minibatches: Number of microbatches the global batch splits into.
        mpmd_mesh: JaxPP MPMD mesh; ``mpmd_dim`` gives the stage count.
        schedule: Pipeline schedule (e.g. 1F1B) used by ``treduce``.

    Returns:
        The callable produced by ``jaxpp.api.mpmd_jit_with_loop``.
    """
    # Capture the graph definition once so loss_fn can rebuild the module
    # from a pure parameter state while tracing.
    model_gdef = nnx.graphdef(model)

    def train_step(model, optimizer, x, y):
        def _loop_body(minibatch):
            x_i, y_i = minibatch

            def loss_fn(params: nnx.State):
                model: Model = nnx.merge(model_gdef, params)
                y_pred = model.call_stages(x_i, n_stages=mpmd_mesh.mpmd_dim)
                loss = ((y_i - y_pred) ** 2).mean()
                # Pre-scale so summing microbatch losses yields the mean.
                return loss / n_minibatches

            grad_fn = jax.value_and_grad(loss_fn)
            loss, grads = grad_fn(nnx.state(model))
            return loss, grads

        # Accumulate (loss, grads) across microbatches under the pipeline
        # schedule; inputs get a leading microbatch axis of size
        # n_minibatches. NOTE(review): the reported shape-mismatch
        # AssertionError fires inside this treduce during tracing.
        loss_acc, grads_acc = jaxpp.api.treduce(
            _loop_body,
            (
                x.reshape(n_minibatches, -1, *x.shape[1:]),
                y.reshape(n_minibatches, -1, *y.shape[1:]),
            ),
            schedule=schedule,
            operation=jaxpp.api.Add,
        )
        optimizer.update(model, grads_acc)
        return loss_acc, model, optimizer

    # Inputs/targets are replicated over the lowering mesh; model/optimizer
    # shardings are rebased from the SPMD mesh onto the lowering mesh.
    replicated_mpmd_sharding = jax.sharding.NamedSharding(
        mesh=mpmd_mesh.lowering_mesh(), spec=jax.P()
    )
    model_mpmd_sharding = jax.tree.map(
        lambda s: s.update(mesh=mpmd_mesh.lowering_mesh()),
        get_sharding(model),
    )
    opt_mpmd_sharding = jax.tree.map(
        lambda s: s.update(mesh=mpmd_mesh.lowering_mesh()),
        get_sharding(optimizer),
    )
    train_step_fn = jaxpp.api.mpmd_jit_with_loop(
        train_step,
        mpmd_mesh=mpmd_mesh,
        in_shardings=(
            model_mpmd_sharding,
            opt_mpmd_sharding,
            replicated_mpmd_sharding,
            replicated_mpmd_sharding,
        ),
        out_shardings=(
            replicated_mpmd_sharding,
            model_mpmd_sharding,
            opt_mpmd_sharding,
        ),
    )
    return train_step_fn
def main():
    """Two-process demo: initialize, build the model, compile, and train."""
    # Each process owns half the visible GPUs (see the run commands above).
    jax.distributed.initialize(
        coordinator_address="0.0.0.0:9444",
        num_processes=2,
        process_id=int(os.environ["JAX_RANK"]),
    )
    n_stages = 2
    # "layers" is the pipeline axis; the remaining 8 // n_stages devices per
    # stage are split over the fsdp/tensor axes.
    spmd_mesh = jax.make_mesh(
        (n_stages, 1, 8 // n_stages), ("layers", "fsdp", "tensor")
    )
    BS = 16
    # Abstract model: shapes + shardings only, no parameter arrays yet.
    model_gdef, model_abs_state = nnx.split(
        eval_shape_with_sharding(lambda: get_model(spmd_mesh))
    )
    mpmd_mesh = jaxpp.api.MpmdMesh(spmd_mesh, "layers")
    with jax.set_mesh(spmd_mesh):
        # Materialize parameters directly into the abstract state's
        # shardings via jit out_shardings.
        @jax.jit(out_shardings=get_sharding(model_abs_state))
        def get_model_state():
            model = Model(
                in_features=16,
                out_features=16,
                n_layers=8,
                rngs=nnx.Rngs(params=0),
                mesh=spmd_mesh,
            )
            return nnx.state(model)

        model_state = get_model_state()
        model = nnx.merge(model_gdef, model_state)
        optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
        train_step_jaxpp_fn = get_jaxpp_train_step(
            model,
            optimizer,
            n_minibatches=4,
            mpmd_mesh=mpmd_mesh,
            schedule=jaxpp.schedules.Std1F1B(num_stages=mpmd_mesh.mpmd_dim),
        )
        # traced = train_step_jaxpp_fn.trace_and_place(
        #     model, optimizer, jnp.ones((BS, 512, 16)), jnp.ones((BS, 512, 1))
        # )
        # Compilation is where the reported treduce AssertionError fires.
        train_step_jaxpp_fn_compiled = train_step_jaxpp_fn.compile(
            model,
            optimizer,
            jnp.ones((BS, 512, 16)),
            jnp.ones((BS, 512, 1)),
        )
        for i in range(10):
            loss, model, optimizer = train_step_jaxpp_fn_compiled(
                model,
                optimizer,
                jax.random.normal(jax.random.PRNGKey(i), (BS, 512, 16)),
                jnp.zeros((BS, 512, 1)),
            )
            print(loss)
main()

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels