3434from jax ._src import state
3535from jax ._src import util
3636from jax ._src .lib import gpu_triton as triton_kernel_call_lib
37- import jax .dlpack
3837import jax .extend as jex
3938from jax .interpreters import ad
4039from jax .interpreters import batching
@@ -435,6 +434,19 @@ def get_or_create_triton_kernel(
435434 kernel = _COMPILED_KERNEL_CACHE .get (cache_key )
436435
437436 if kernel is None :
437+ # First, check that the kernel signature and the reconstructed signature have the
438+ # same number of parameters. A mismatch can occur due to differences in
439+ # `triton_call(input_output_aliases=)` handling between jax-triton versions.
440+ if len (fn .signature .parameters ) != len (signature ):
441+ raise TypeError (
442+ f"Number of parameters in the kernel '{ fn } ' signature "
443+ f"({ len (fn .signature .parameters )} : { fn .signature } ) "
444+ f"does not match reconstructed signature ({ len (signature )} : { signature } ). "
445+ "If the kernel was working on an older version of jax-triton and its "
446+ "triton_call() launcher uses `input_output_aliases` argument, note that "
447+ "implicit output arguments are no longer required for aliased args."
448+ )
449+
438450 opts = {
439451 "num_warps" : num_warps ,
440452 "num_stages" : num_stages ,
@@ -543,8 +555,17 @@ def triton_kernel_call_lowering(
543555 for idx , dtype , v in scalar_args :
544556 args .insert (idx , v )
545557 arg_dtypes .insert (idx , dtype )
546- args .extend (ctx .avals_out )
547- arg_dtypes .extend (map (get_triton_type , ctx .avals_out ))
558+ # Extract only the output avals not referenced in the input_output_aliases mapping.
559+ assert isinstance (input_output_aliases , tuple )
560+ input_output_aliases = dict (input_output_aliases )
561+ strictly_out_avals = [
562+ aval
563+ for i , aval in enumerate (ctx .avals_out )
564+ if i not in input_output_aliases .values ()
565+ ]
566+ args .extend (strictly_out_avals )
567+ arg_dtypes .extend (map (get_triton_type , strictly_out_avals ))
568+
548569 named_args = dict (unsafe_zip (fn .arg_names , args ))
549570
550571 if isinstance (fn , autotuner .Autotuner ):
@@ -606,6 +627,10 @@ def prune_configs(configs, named_args, **kwargs):
606627 "`kernel` must be a Triton `JITFunction`, `Heuristics` or `Autotuner`."
607628 )
608629
630+ output2input = {v : k for k , v in input_output_aliases .items ()}
631+ if len (output2input ) != len (input_output_aliases ):
632+ raise ValueError ("input_output_aliases must be a bijection" )
633+
609634 outputs_offset = len (ctx .avals_in ) + len (scalar_args )
610635 config_params = []
611636 for config in configs :
@@ -616,9 +641,13 @@ def prune_configs(configs, named_args, **kwargs):
616641 if callable (zeroed_outputs ):
617642 config_zeroed_outputs = config_zeroed_outputs (config_metaparams )
618643
644+ # zeroed_params_with_sizes is a dict output_arg_idx -> aval_size_bytes
645+ # config_zeroed_outputs is output ordinal numbers
619646 zeroed_params_with_sizes = {
620- i + outputs_offset : aval_size_bytes (ctx .avals_out [i ])
621- for i in sorted (config_zeroed_outputs )
647+ output2input [i ] if i in output2input else i + outputs_offset : aval_size_bytes (
648+ ctx .avals_out [i ]
649+ )
650+ for i in sorted (config_zeroed_outputs )
622651 }
623652
624653 config_params .append (
@@ -688,7 +717,7 @@ def prune_configs(configs, named_args, **kwargs):
688717 named_scalar_args = {fn .arg_names [i ]: v for i , _ , v in scalar_args }
689718 input_output_aliases_with_sizes = tuple (
690719 (input_idx , output_idx , aval_size_bytes (ctx .avals_in [input_idx ]))
691- for input_idx , output_idx in input_output_aliases
720+ for input_idx , output_idx in input_output_aliases . items ()
692721 )
693722 kernel_call = triton_kernel_call_lib .TritonAutotunedKernelCall (
694723 f"{ kernel_call_name } ({ fn .fn .__name__ } ) { named_scalar_args } " ,
@@ -703,7 +732,7 @@ def prune_configs(configs, named_args, **kwargs):
703732 custom_call_target_name ,
704733 api_version = 2 ,
705734 backend_config = zlib .compress (call_proto ),
706- operand_output_aliases = dict ( input_output_aliases ) ,
735+ operand_output_aliases = input_output_aliases ,
707736 )
708737 return rule (ctx , * array_args )
709738
0 commit comments