Examples of how to use fori_loop for gradient accumulation & clearer exceptions #5113

@qGentry

Description

Hey folks,

It would be nice to have some examples in the Flax NNX documentation on how to do gradient accumulation with nnx.fori_loop.

I've tried lots of different combinations of how to pass arguments, what I can add to the closure and what I can't, but haven't succeeded.

My train step looks something like this:

# TrainSession is a NamedTuple with 'model: nnx.Module' and 'optimizer: nnx.Optimizer' fields

def get_training_step(
    forward_pass_cfg: ForwardPassConfigBase,
    batch_sharding: PyTree,
    num_minibatches: int = 1,
) -> Callable[
    [PyTree, TrainSession],
    tuple[DeviceArray, PyTree],
]:
    forward_pass_fn = get_forward_pass(forward_pass_cfg)

    @nnx.jit(donate_argnums=(1,))
    def training_step(
        batch: PyTree,
        train_session: TrainSession,
    ) -> tuple[DeviceArray, PyTree]:

        def _loop_body(
            minibatch_idx,
            carry: tuple[nnx.State, DeviceArray, TrainSession],
        ) -> tuple[nnx.State, DeviceArray, TrainSession]:
            g_accum, loss_accum, train_session = carry
            minibatch = get_sharded_minibatch(
                batch=batch,
                batch_sharding=batch_sharding,
                minibatch_idx=minibatch_idx,
                num_minibatches=num_minibatches,
            )

            def _loss_fn(model) -> tuple[DeviceArray, PyTree]:
                return forward_pass_fn(
                    model=model,
                    batch=minibatch,
                    step=train_session.optimizer.step.value,
                    config=forward_pass_cfg,
                )

            (loss, outputs), grads = nnx.value_and_grad(_loss_fn, has_aux=True)(
                train_session.model
            )
            g_accum = jax.tree.map(
                lambda gm, g: gm + g, g_accum, grads
            )
            loss_accum = loss_accum + loss
            return g_accum, loss_accum, train_session

        g_accum = jax.tree.map(jnp.zeros_like, nnx.state(train_session.model))
        g_accum, loss_accum, train_session = nnx.fori_loop(
            lower=0, 
            upper=num_minibatches,
            body_fun=_loop_body,
            init_val=(g_accum, 0.0, train_session)
        )

        g_accum = jax.tree.map(lambda g: g / num_minibatches, g_accum)
        loss_accum = loss_accum / num_minibatches

        train_session.optimizer.update(train_session.model, g_accum)
        return loss_accum, {}

    return training_step

And I'm getting the following error:

ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from `outer_index` field, you might have modified reference structure within the body function, which is not allowed.

But I'm not sure what it means or how I can fix it.
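
If I read the error correctly, whatever tuple the body receives has to be returned with exactly the same pytree and reference structure, i.e. the module just gets threaded through the carry. A toy loop like this (my own sketch with a plain nnx.Linear, not something taken from the docs) is what I understand the constraint to allow:

def body(i, carry):
    # Return the carry with the same structure it came in with:
    # the module is passed through unchanged, only the array leaf is updated.
    model, x = carry
    return model, model(x)

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
x = jnp.ones((2, 4))
model, y = nnx.fori_loop(0, 4, body, (model, x))

But it's not obvious to me how my train step above violates this.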

It would be nice if Flax provided examples in the documentation on how to approach this properly: what can be done, what can't, how to work with closures, etc.
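
For concreteness, even a small self-contained snippet along these lines would answer most of my questions. This is just my guess at how it's supposed to look (toy nnx.Linear model, plain MSE loss, manual SGD step instead of an optimizer), so the exact pattern may well be wrong, which is kind of the point of asking for an official example:

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 1, rngs=nnx.Rngs(0))
batch_x = jnp.ones((8, 4))   # 8 examples, split into 4 minibatches of 2
batch_y = jnp.zeros((8, 1))
num_minibatches = 4

def body(i, carry):
    # Keep the carry structure identical on the way in and out:
    # the model is only read, gradients and loss are accumulated.
    grad_accum, loss_accum, model = carry
    mb_x = jax.lax.dynamic_slice_in_dim(batch_x, i * 2, 2)
    mb_y = jax.lax.dynamic_slice_in_dim(batch_y, i * 2, 2)

    def loss_fn(model):
        return jnp.mean((model(mb_x) - mb_y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    grad_accum = jax.tree.map(lambda a, g: a + g, grad_accum, grads)
    return grad_accum, loss_accum + loss, model

grad_accum = jax.tree.map(jnp.zeros_like, nnx.state(model, nnx.Param))
grad_accum, loss_sum, model = nnx.fori_loop(
    0, num_minibatches, body, (grad_accum, jnp.array(0.0), model)
)

# Average the accumulated gradients and apply a plain SGD step
# (a stand-in for optimizer.update in my real code).
grad_avg = jax.tree.map(lambda g: g / num_minibatches, grad_accum)
params = nnx.state(model, nnx.Param)
nnx.update(model, jax.tree.map(lambda p, g: p - 1e-2 * g, params, grad_avg))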
