fix(jax): include in_out_argnames and stage argnames in FFI registry …#1249
Adityakk9031 wants to merge 1 commit into NVIDIA:main
…cache keys Signed-off-by: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com>
📝 Walkthrough
Adds hashable components (
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 3 passed
Greptile Summary
Fixes a caching bug where FFI wrappers with different argument configurations were incorrectly deduplicated. The registry cache keys now include
Confidence Score: 5/5
Important Files Changed
Last reviewed commit: dad527b |
🧹 Nitpick comments (2)
warp/_src/jax_experimental/ffi.py (2)
1206-1214: Use `frozenset` (or `tuple(sorted(...))`) for order-independent cache key semantics.
`in_out_argnames` is immediately converted to a `set` in `FfiKernel.__init__` (line 170), so argument order is semantically irrelevant. Using `tuple(in_out_argnames)` makes the cache key order-sensitive: `["a", "b"]` and `["b", "a"]` produce different keys but create behaviorally identical `FfiKernel` objects, causing redundant FFI target registrations.
♻️ Proposed fix
- hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
+ hashable_in_out = frozenset(in_out_argnames) if in_out_argnames is not None else None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@warp/_src/jax_experimental/ffi.py` around lines 1206 - 1214, The cache key currently uses tuple(in_out_argnames) which is order-sensitive even though FfiKernel.__init__ turns in_out_argnames into a set, causing equivalent kernels with different arg order to miss cache hits; change the construction of hashable_in_out to an order-independent representation (e.g., frozenset(in_out_argnames) or tuple(sorted(in_out_argnames))) and update the key creation that references hashable_in_out so identical FfiKernel objects produce the same cache key.
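The order-sensitivity concern can be illustrated with a minimal sketch in plain Python. It does not use Warp; `names_a`, `names_b`, and the toy `registry` dict are hypothetical stand-ins for two `in_out_argnames` lists that differ only in order and for the FFI registry.

```python
# Hypothetical argument-name lists that differ only in order.
names_a = ["a", "b"]
names_b = ["b", "a"]

# Tuples compare element by element, so order matters.
tuple_keys_equal = tuple(names_a) == tuple(names_b)

# Frozensets and sorted tuples ignore the original order.
frozenset_keys_equal = frozenset(names_a) == frozenset(names_b)
sorted_keys_equal = tuple(sorted(names_a)) == tuple(sorted(names_b))

# A dict keyed by the tuple form ends up with two entries
# for what is logically one configuration.
registry = {}
registry[("kernel", tuple(names_a))] = "wrapper 1"
registry[("kernel", tuple(names_b))] = "wrapper 2"
tuple_entries = len(registry)

print(tuple_keys_equal, frozenset_keys_equal, sorted_keys_equal, tuple_entries)
```

Either `frozenset(...)` or `tuple(sorted(...))` would collapse the two entries into one; which of the two suits the real registry is a judgment call for the maintainers.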
1553-1564: Same order-sensitivity concern for all three hashable argname variables in `jax_callable`.
All three parameters (`in_out_argnames`, `stage_in_argnames`, `stage_out_argnames`) are stored as sets inside `FfiCallable.__init__` (lines 529, 525, 526 respectively), so their ordering is semantically irrelevant. Using `tuple(...)` makes the cache key order-dependent, causing unnecessary duplicate `FfiCallable` registrations and FFI target registrations.
♻️ Proposed fix
- hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
- hashable_stage_in = tuple(stage_in_argnames) if stage_in_argnames is not None else None
- hashable_stage_out = tuple(stage_out_argnames) if stage_out_argnames is not None else None
+ hashable_in_out = frozenset(in_out_argnames) if in_out_argnames is not None else None
+ hashable_stage_in = frozenset(stage_in_argnames) if stage_in_argnames is not None else None
+ hashable_stage_out = frozenset(stage_out_argnames) if stage_out_argnames is not None else None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@warp/_src/jax_experimental/ffi.py` around lines 1553 - 1564, The cache key in jax_callable is order-sensitive because in_out_argnames, stage_in_argnames, and stage_out_argnames (which are stored as sets in FfiCallable.__init__) are converted with tuple(...); change those conversions to an order-insensitive representation (e.g., use frozenset(...) or tuple(sorted(...))) when building key so the key does not depend on arbitrary set iteration order and avoids duplicate FfiCallable/FFI registrations; update the creation of hashable_in_out, hashable_stage_in, and hashable_stage_out accordingly where key is constructed.
|
@shi-eric have a look |
|
@c0d1f1ed have a look |
|
@Adityakk9031 Do you have a need for these changes? What is your use case? |
|
@christophercrouzet The fix addresses a silent correctness bug where reusing the same kernel with different in_out_argnames or staging configurations would return a stale cached wrapper. This ensures JAX FFI registrations correctly and uniquely identify the intended argument behavior, as reported in #1215. |
|
@Adityakk9031 Thanks for the AI summary, but my question was why you need this. If you're pinging each team member daily on this PR (and on #1248), surely it must be because you have an urgent need for it in a project of yours? Making such changes with an AI agent is fast, but reviewing and iterating on these takes time, energy, and a holistic understanding of the codebase. Please help us understand why we should prioritize your pull requests over what we're currently working on. |
|
@christophercrouzet Sir, thanks for the clarification. The root cause and suggested fix were already described in the issue, and I simply implemented that fix. Also, sorry for the earlier AI-generated summary; the code change itself was written by me, not by an agent. |
|
Also, sorry for tagging multiple people; I can be a little dumb sometimes. |
|
@Adityakk9031 No worries, and thanks for your willingness to contribute to the project. We're a small team with a lot on our plate, and when issues are already assigned to team members, it generally means we have them on our radar and plan to address them as part of our roadmap. Uncoordinated PRs for these issues can end up adding to our workload rather than reducing it. If you do open a PR and need it reviewed promptly, please help us understand why. For example, if it's blocking a project of yours or addresses a critical bug you're hitting. Without that context, it's difficult for us to justify prioritizing an external PR over our current work. We'll be updating our contribution guidelines to make this clearer for everyone, and we'll take a look at your PR when we get a chance. Thanks for your understanding! |
Issue:
Both `jax_kernel()` and `jax_callable()` use a registry dict to cache and deduplicate FFI wrappers. The key tuple used to look up the registry was missing parameters that affect the wrapper's behaviour: `in_out_argnames` in `jax_kernel()`, and `in_out_argnames`, `stage_in_argnames`, `stage_out_argnames` in `jax_callable()`. This meant calling either function with the same kernel but different argname configurations silently returned the first cached object with the wrong configuration.
Fix:
Added the missing parameters to the key tuples in both functions in `warp/_src/jax_experimental/ffi.py`. Since `list` is not hashable, they are converted to `tuple` before being added to the key. Two wrappers with different configurations now correctly produce different keys and are stored as separate objects in the registry.
Summary by CodeRabbit