nnx.clone: add arrays parameter to also copy underlying buffers#5475
Open
dparikh79 wants to merge 1 commit into
Open
nnx.clone: add arrays parameter to also copy underlying buffers#5475dparikh79 wants to merge 1 commit into
dparikh79 wants to merge 1 commit into
Conversation
nnx.clone returns new Variable wrappers but shares the underlying jax.Array buffers with the original. This is cheap and matches a copy-on-write model, but jit(donate_argnums=...) rejects the clone because JAX refuses to donate the same buffer twice. Add arrays=True (default False) which runs jax.tree.map(jnp.copy, state) between split and merge. The default behaviour is unchanged. Using jnp.copy (rather than jnp.array) preserves sharding on multi-device arrays. Setting arrays=True implies variables=True since a buffer copy without a new Variable wrapper is not meaningful. Two regression tests in tests/nnx/module_test.py: - test_clone_arrays_distinct_buffers: asserts unsafe_buffer_pointer() differs between original and arrays=True clone, and matches between original and default clone. - test_clone_arrays_donate_argnums: reporter's repro from google#5461 using jit(donate_argnums=(0,)) on state with a Linear plus a cloned Linear. Fixes google#5461 Signed-off-by: Dhruvil <dhruvilparikh79@gmail.com>
samanklesaria
approved these changes
Jun 1, 2026
Collaborator
samanklesaria
left a comment
There was a problem hiding this comment.
Looks good to me!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Adds
arrays=Truetonnx.cloneper #5461. When set, the underlyingjax.Arraybuffers of eachVariableare copied so the clone has independent memory and works withjit(donate_argnums=...). DefaultFalseso existing callers see no change. Settingarrays=Trueforcesvariables=Truebecause a buffer copy without a fresh wrapper has nowhere to live.Implementation:
jax.tree.map(jax.numpy.copy, state)betweensplitandmerge. Usedjnp.copyrather thanjnp.arrayto preserve sharding on multi-device arrays.Expanded the docstring to call out the default buffer-sharing per @samanklesaria's note that it should be better documented.
Two regression tests in
tests/nnx/module_test.py.test_clone_arrays_distinct_buffersassertsunsafe_buffer_pointer()differs afterarrays=Trueand matches for the default clone.test_clone_arrays_donate_argnumsis @johnlyzhou's exactdonate_argnumsrepro from the issue. Existingtest_clonestill passes. Full sweep: 77 passed / 1 skipped inmodule_test.py, 152 passed ingraph_utils_test.py.Fixes #5461
Checklist