Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] refactor remat #4662

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,27 +866,6 @@ def remat(
static_argnums: int | tuple[int, ...] = (),
policy: tp.Callable[..., bool] | None = None,
) -> F | tp.Callable[[F], F]:
if isinstance(f, Missing):
return functools.partial(
remat,
prevent_cse=prevent_cse,
static_argnums=static_argnums,
policy=policy,
) # type: ignore[return-value]

return resolve_kwargs()(
graph.update_context('remat')(
general.split_inputs(
jax.checkpoint(
general.merge_inputs(f, ctxtag='remat'),
prevent_cse=prevent_cse,
static_argnums=static_argnums,
policy=policy,
),
ctxtag='remat',
),
)
)
"""A 'lifted' version of the
`jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
(a.k.a. ``jax.remat``).
Expand All @@ -901,4 +880,26 @@ def remat(
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
"""
if isinstance(f, Missing):
return functools.partial(
remat,
prevent_cse=prevent_cse,
static_argnums=static_argnums,
policy=policy,
) # type: ignore[return-value]

@resolve_kwargs()
@graph.update_context('remat')
@general.split_inputs(ctxtag='remat')
@functools.partial(
jax.checkpoint,
prevent_cse=prevent_cse,
static_argnums=static_argnums,
policy=policy,
)
@general.merge_inputs(ctxtag='remat')
@functools.wraps(f)
def remat_wrapper(*args, **kwargs):
return f(*args, **kwargs)

return remat_wrapper
Loading