File tree Expand file tree Collapse file tree 1 file changed +8
-0
lines changed
warp/_src/jax_experimental Expand file tree Collapse file tree 1 file changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -1203,13 +1203,15 @@ def jax_kernel(
12031203 hashable_launch_dims = launch_dims
12041204
12051205 if not enable_backward :
1206+ hashable_in_out = tuple (in_out_argnames ) if in_out_argnames is not None else None
12061207 key = (
12071208 kernel .func ,
12081209 kernel .sig ,
12091210 num_outputs ,
12101211 vmap_method ,
12111212 hashable_launch_dims ,
12121213 hashable_output_dims ,
1214+ hashable_in_out ,
12131215 module_preload_mode ,
12141216 has_side_effect ,
12151217 )
@@ -1548,12 +1550,18 @@ def jax_callable(
15481550 hashable_output_dims = output_dims
15491551
15501552 # Note: we don't include graph_cache_max in the key, it is applied below.
1553+ hashable_in_out = tuple (in_out_argnames ) if in_out_argnames is not None else None
1554+ hashable_stage_in = tuple (stage_in_argnames ) if stage_in_argnames is not None else None
1555+ hashable_stage_out = tuple (stage_out_argnames ) if stage_out_argnames is not None else None
15511556 key = (
15521557 func ,
15531558 num_outputs ,
15541559 graph_mode ,
15551560 vmap_method ,
15561561 hashable_output_dims ,
1562+ hashable_in_out ,
1563+ hashable_stage_in ,
1564+ hashable_stage_out ,
15571565 module_preload_mode ,
15581566 has_side_effect ,
15591567 )
You can’t perform that action at this time.
0 commit comments