nnx.clone: add arrays parameter to also copy underlying buffers #5475

Open
dparikh79 wants to merge 1 commit into google:main from dparikh79:fix/5461-nnx-clone-shares-buffers

Conversation

@dparikh79

What does this PR do?

Adds an arrays parameter to nnx.clone, per #5461. With arrays=True, the underlying jax.Array buffers of each Variable are copied, so the clone has independent memory and works with jit(donate_argnums=...). The default is False, so existing callers see no change. Setting arrays=True forces variables=True because a buffer copy without a fresh wrapper has nowhere to live.

Implementation: jax.tree.map(jax.numpy.copy, state) between split and merge. Used jnp.copy rather than jnp.array to preserve sharding on multi-device arrays.
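A minimal sketch of the copy step described above, assuming a plain dict of arrays as a hypothetical stand-in for the state that nnx.split would return (the dict, `buffer_ptrs`, and the variable names are illustrative and not part of the PR). It only shows that mapping jnp.copy over a pytree keeps the values and yields distinct buffers:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the state produced by nnx.split:
# a plain pytree (dict) of jax.Array leaves.
state = {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}

# Copy every leaf so the result owns fresh buffers.
copied = jax.tree.map(jnp.copy, state)


def buffer_ptrs(tree):
    """Return the device buffer address of each leaf (illustrative helper)."""
    return [leaf.unsafe_buffer_pointer() for leaf in jax.tree.leaves(tree)]


# The copy holds the same values as the original...
same_values = all(
    bool(jnp.array_equal(a, b))
    for a, b in zip(jax.tree.leaves(state), jax.tree.leaves(copied))
)

# ...but each leaf lives in a different buffer.
distinct_buffers = all(
    p != q for p, q in zip(buffer_ptrs(state), buffer_ptrs(copied))
)
```

In the PR itself this map runs on the real NNX state between split and merge; the sketch leaves that part out.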

Expanded the docstring to call out the default buffer-sharing per @samanklesaria's note that it should be better documented.

Two regression tests in tests/nnx/module_test.py. test_clone_arrays_distinct_buffers asserts unsafe_buffer_pointer() differs after arrays=True and matches for the default clone. test_clone_arrays_donate_argnums is @johnlyzhou's exact donate_argnums repro from the issue. Existing test_clone still passes. Full sweep: 77 passed / 1 skipped in module_test.py, 152 passed in graph_utils_test.py.
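As a rough illustration of the donation scenario, the sketch below uses plain JAX arrays in place of the Linear modules from the issue (the `step` function and the dict keys are hypothetical). It runs only the working case, where the second entry is an independent copy; placing the same array under both keys is the buffer-sharing situation the PR describes as rejected, and it is deliberately not executed here:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

w = jnp.arange(4.0)

# Using `w` under both keys would pass one buffer to the donating jit twice.
# jnp.copy gives the second entry its own buffer instead.
state = {"original": w, "clone": jnp.copy(w)}


@functools.partial(jax.jit, donate_argnums=(0,))
def step(s):
    # Donation lets XLA reuse the input buffers for the outputs
    # on backends that support it.
    return jax.tree.map(lambda x: x + 1.0, s)


new_state = step(state)

# Only the returned arrays are read; the donated inputs may be invalidated.
original_out = np.asarray(new_state["original"]).tolist()
clone_out = np.asarray(new_state["clone"]).tolist()
```

On backends without donation support JAX typically emits a warning and falls back to a normal call, so the sketch should still run there.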

Fixes #5461

Checklist

nnx.clone returns new Variable wrappers but shares the underlying
jax.Array buffers with the original. This is cheap and matches a
copy-on-write model, but jit(donate_argnums=...) rejects the clone
because JAX refuses to donate the same buffer twice.

Add arrays=True (default False) which runs jax.tree.map(jnp.copy, state)
between split and merge. The default behaviour is unchanged. Using
jnp.copy (rather than jnp.array) preserves sharding on multi-device
arrays. Setting arrays=True implies variables=True since a buffer copy
without a new Variable wrapper is not meaningful.

Two regression tests in tests/nnx/module_test.py:
- test_clone_arrays_distinct_buffers: asserts unsafe_buffer_pointer()
  differs between original and arrays=True clone, and matches between
  original and default clone.
- test_clone_arrays_donate_argnums: reporter's repro from google#5461 using
  jit(donate_argnums=(0,)) on state with a Linear plus a cloned Linear.

Fixes google#5461

Signed-off-by: Dhruvil <dhruvilparikh79@gmail.com>
Collaborator

@samanklesaria samanklesaria left a comment

Looks good to me!

Development

Successfully merging this pull request may close these issues.

nnx.clone creates buffer copies with the same IDs, causing errors with donate_argnums

2 participants