diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 61a7ae267..1099e04b9 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -866,27 +866,6 @@ def remat( static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, ) -> F | tp.Callable[[F], F]: - if isinstance(f, Missing): - return functools.partial( - remat, - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, - ) # type: ignore[return-value] - - return resolve_kwargs()( - graph.update_context('remat')( - general.split_inputs( - jax.checkpoint( - general.merge_inputs(f, ctxtag='remat'), - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, - ), - ctxtag='remat', - ), - ) - ) """A 'lifted' version of the `jax.checkpoint `__ (a.k.a. ``jax.remat``). @@ -901,4 +880,26 @@ def remat( `fundamentals of jax.checkpoint `_ and `practical notes `_. """ + if isinstance(f, Missing): + return functools.partial( + remat, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) # type: ignore[return-value] + + @resolve_kwargs() + @graph.update_context('remat') + @general.split_inputs(ctxtag='remat') + @functools.partial( + jax.checkpoint, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + @general.merge_inputs(ctxtag='remat') + @functools.wraps(f) + def remat_wrapper(*args, **kwargs): + return f(*args, **kwargs) + return remat_wrapper