Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions warp/_src/jax_experimental/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,13 +1203,15 @@ def jax_kernel(
hashable_launch_dims = launch_dims

if not enable_backward:
hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
key = (
kernel.func,
kernel.sig,
num_outputs,
vmap_method,
hashable_launch_dims,
hashable_output_dims,
hashable_in_out,
module_preload_mode,
has_side_effect,
)
Expand Down Expand Up @@ -1548,12 +1550,18 @@ def jax_callable(
hashable_output_dims = output_dims

# Note: we don't include graph_cache_max in the key, it is applied below.
hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None
hashable_stage_in = tuple(stage_in_argnames) if stage_in_argnames is not None else None
hashable_stage_out = tuple(stage_out_argnames) if stage_out_argnames is not None else None
key = (
func,
num_outputs,
graph_mode,
vmap_method,
hashable_output_dims,
hashable_in_out,
hashable_stage_in,
hashable_stage_out,
module_preload_mode,
has_side_effect,
)
Expand Down