Skip to content

Commit dad527b

Browse files
committed
fix(jax): include in_out_argnames and stage argnames in FFI registry cache keys
Signed-off-by: Aditya kumar singh <143548997+Adityakk9031@users.noreply.github.com>
1 parent 46b1773 commit dad527b

File tree

1 file changed

+8
-0
lines changed
  • warp/_src/jax_experimental

1 file changed

+8
-0
lines changed

warp/_src/jax_experimental/ffi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,13 +1203,15 @@ def jax_kernel(
12031203
hashable_launch_dims = launch_dims
12041204

12051205
if not enable_backward:
1206+
hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
12061207
key = (
12071208
kernel.func,
12081209
kernel.sig,
12091210
num_outputs,
12101211
vmap_method,
12111212
hashable_launch_dims,
12121213
hashable_output_dims,
1214+
hashable_in_out,
12131215
module_preload_mode,
12141216
has_side_effect,
12151217
)
@@ -1548,12 +1550,18 @@ def jax_callable(
15481550
hashable_output_dims = output_dims
15491551

15501552
# Note: we don't include graph_cache_max in the key, it is applied below.
1553+
hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
1554+
hashable_stage_in = tuple(stage_in_argnames) if stage_in_argnames is not None else None
1555+
hashable_stage_out = tuple(stage_out_argnames) if stage_out_argnames is not None else None
15511556
key = (
15521557
func,
15531558
num_outputs,
15541559
graph_mode,
15551560
vmap_method,
15561561
hashable_output_dims,
1562+
hashable_in_out,
1563+
hashable_stage_in,
1564+
hashable_stage_out,
15571565
module_preload_mode,
15581566
has_side_effect,
15591567
)

0 commit comments

Comments
 (0)