1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """Module for calling Triton kernels from JAX."""
15+ """Module for calling Triton or Triton.Gluon kernels from JAX."""
1616
1717from __future__ import annotations
1818
4646CAN_USE_TRITON = False
4747try :
4848 import triton
49- from triton .compiler import code_generator as code_gen
5049 from triton .compiler import compiler as tc
5150 import triton .language as tl
5251 from triton .runtime import autotuner
5352 import triton ._C .libtriton as _triton
5453 import triton .backends .nvidia .compiler as cb
5554
55+ import triton .experimental .gluon ._runtime as gl_runtime
56+ from triton .experimental .gluon import language as gl
57+
5658 CAN_USE_TRITON = True
5759except ModuleNotFoundError :
5860 pass
@@ -115,7 +117,7 @@ def avals_to_layouts(avals):
115117def get_triton_type (obj : Any ) -> str :
116118 if isinstance (obj , (jax .core .ShapedArray , state .AbstractRef )):
117119 return f"*{ _JAX_TO_TRITON_TYPE_MAP [obj .dtype ]} "
118- if isinstance (obj , tl .constexpr ):
120+ if isinstance (obj , ( tl .constexpr , gl . constexpr ) ):
119121 obj = obj .value
120122 if isinstance (obj , bool ): # True == isinstance(True, int) !!!
121123 return "B"
@@ -160,10 +162,8 @@ def aval_size_bytes(aval):
160162 return np .dtype (aval .dtype ).itemsize * aval .size
161163
162164
163- def get_cuda_backend (device , compute_capability ):
164- target = cb .GPUTarget ("cuda" , compute_capability , 32 )
165- backend = cb .CUDABackend (target )
166- return backend
165+ def make_gpu_target_cuda (device , compute_capability ):
166+ return cb .GPUTarget ("cuda" , compute_capability , 32 )
167167
168168
169169_IS_HIPBackend_PATCHED = False
@@ -199,15 +199,13 @@ def fixed_is_within_2gb(arg):
199199 hb .HIPBackend .is_within_2gb = fixed_is_within_2gb
200200
201201
202- def get_hip_backend (device , compute_capability ):
202+ def make_gpu_target_hip (device , compute_capability ):
203203 # TODO(Arech): remove _patch_hip_backend() once Triton releases a fix
204204 _patch_hip_backend ()
205205
206206 arch = triton_kernel_call_lib .get_arch_details (device )
207207 arch = arch .split (":" )[0 ]
208- target = hb .GPUTarget ("hip" , arch , 64 )
209- backend = hb .HIPBackend (target )
210- return backend
208+ return hb .GPUTarget ("hip" , arch , 64 )
211209
212210
213211@dataclasses .dataclass
@@ -358,7 +356,7 @@ def compile_ttir_to_hsaco_inplace(
358356
359357
360358def get_or_create_triton_kernel (
361- backend_init_func ,
359+ make_gpu_target_func ,
362360 platform ,
363361 fn ,
364362 arg_dtypes ,
@@ -385,7 +383,8 @@ def get_or_create_triton_kernel(
385383 if num_ctas > 1 and compute_capability < 90 :
386384 raise ValueError ("num_ctas > 1 unsupported before Hopper." )
387385
388- backend = backend_init_func (device , compute_capability )
386+ gpu_target = make_gpu_target_func (device , compute_capability )
387+ backend = triton .compiler .make_backend (gpu_target )
389388
390389 signature = {fn .arg_names [i ]: v for i , v in enumerate (arg_dtypes )}
391390 # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
@@ -470,16 +469,15 @@ def get_or_create_triton_kernel(
470469 backend .load_dialects (context )
471470 codegen_fns = backend .get_codegen_implementation (options )
472471
473- module = code_gen .ast_to_ttir (
474- fn ,
475- tc .ASTSource (
476- fn , constexprs = constants , signature = signature , attrs = attrs
477- ),
478- options = options ,
479- codegen_fns = codegen_fns ,
480- context = context ,
481- module_map = backend .get_module_map (),
472+ real_ASTSource = (
473+ gl_runtime .GluonASTSource
474+ if isinstance (fn , gl_runtime .GluonJITFunction )
475+ else tc .ASTSource
482476 )
477+ module = real_ASTSource (
478+ fn , constexprs = constants , signature = signature , attrs = attrs
479+ ).make_ir (gpu_target , options , codegen_fns , backend .get_module_map (), context )
480+
483481 ttir = str (module )
484482
485483 compilation_result = compile_ttir_inplace (
@@ -529,7 +527,7 @@ def get_or_create_triton_kernel(
529527
530528
531529def triton_kernel_call_lowering (
532- backend_init_func ,
530+ make_gpu_target_func ,
533531 ctx ,
534532 * array_args ,
535533 fn ,
@@ -547,16 +545,21 @@ def triton_kernel_call_lowering(
547545 zeroed_outputs ,
548546 debug ,
549547 serialized_metadata ,
550- ** metaparams ,
548+ metaparams : tuple [ tuple [ str , Any ], ...] ,
551549):
550+ # We have to pass the metaparams dictionary as a tuple because lowering via
551+ # xla_primitive_callable() requires the parameters to be hashable.
552+ assert isinstance (metaparams , tuple ), "metaparams must be tuple[tuple[str, Any], ...]"
553+ metaparams = dict (metaparams ) # will crash if the tuple format is incompatible
554+
552555 kernel_call_name = name
553556 args = list (ctx .avals_in )
554557 arg_dtypes = list (map (get_triton_type , ctx .avals_in ))
555558 for idx , dtype , v in scalar_args :
556559 args .insert (idx , v )
557560 arg_dtypes .insert (idx , dtype )
558561 # Extract only the output avals not referenced in the input_output_aliases mapping.
559- assert isinstance (input_output_aliases , tuple )
562+ assert isinstance (input_output_aliases , tuple ), "input_output_aliases must be a tuple"
560563 input_output_aliases = dict (input_output_aliases )
561564 strictly_out_avals = [
562565 aval
@@ -622,9 +625,9 @@ def prune_configs(configs, named_args, **kwargs):
622625 configs = updated_configs
623626 fn = fn .fn
624627
625- if not isinstance (fn , triton .JITFunction ):
628+ if not isinstance (fn , ( triton .JITFunction , gl_runtime . GluonJITFunction ) ):
626629 raise ValueError (
627- "`kernel` must be a Triton `JITFunction`, `Heuristics` or `Autotuner`."
630+ "`kernel` must be a Triton `JITFunction`, `GluonJITFunction`, `Heuristics` or `Autotuner`."
628631 )
629632
630633 output2input = {v : k for k , v in input_output_aliases .items ()}
@@ -664,7 +667,7 @@ def prune_configs(configs, named_args, **kwargs):
664667 kernel_calls = []
665668 for params in config_params :
666669 kernel , specialization_attr = get_or_create_triton_kernel (
667- backend_init_func ,
670+ make_gpu_target_func ,
668671 ctx .module_context .platforms [0 ],
669672 fn ,
670673 arg_dtypes ,
@@ -739,13 +742,13 @@ def prune_configs(configs, named_args, **kwargs):
739742
740743mlir .register_lowering (
741744 triton_kernel_call_p ,
742- functools .partial (triton_kernel_call_lowering , get_cuda_backend ),
745+ functools .partial (triton_kernel_call_lowering , make_gpu_target_cuda ),
743746 platform = "cuda" ,
744747)
745748
746749mlir .register_lowering (
747750 triton_kernel_call_p ,
748- functools .partial (triton_kernel_call_lowering , get_hip_backend ),
751+ functools .partial (triton_kernel_call_lowering , make_gpu_target_hip ),
749752 platform = "rocm" ,
750753)
751754
@@ -791,6 +794,7 @@ def triton_call(
791794 * args : jax .Array | bool | int | float | np .float32 ,
792795 kernel : (
793796 triton .JITFunction
797+ | gl_runtime .GluonJITFunction
794798 | triton .runtime .Heuristics
795799 | triton .runtime .Autotuner
796800 ),
@@ -865,7 +869,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
865869 Args:
866870 *args: Inputs for the Triton kernel.
867871 kernel: A Triton kernel (e.g. a function decorated with `triton.jit`). All
868- static values should be annotated with `triton.language.constexpr`.
872+ static values should be annotated with `triton.language.constexpr` or
873+ `triton.experimental.gluon.language.constexpr`.
869874 out_shape: A `jax.ShapeDtypeStruct` (or something that has `.shape` and
870875 `.dtype` attributes) or a sequence thereof that specify the output(s) of
871876 the kernel. Pointers for each of the `jax.ShapeDtypeStruct`s in
@@ -880,14 +885,17 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
880885 indices. Providing a mapping will alias the corresponding buffers.
881886 zeroed_outputs: A sequence of indices, or a function returning a sequence of
882887 indices, for outputs that should be zeroed before the kernel is launched.
888+ Note that indices of input-output arguments (i.e. those aliased through
889+ `input_output_aliases`) may also be listed here; such aliased buffers are
890+ zeroed as outputs.
883891 num_warps: The number of warps used to execute the Triton kernel.
884892 num_stages: The number of stages emitted by the Triton compiler.
885893 num_ctas: The size of thread blocks per cluster to be used on GPUs with
886894 compute capabilities >= 9.0. It must be less or equal to 8.
887895 debug: Prints out intermediate IRs if True for debugging purposes.
888896 serialized_metadata: Arbitrary metadata that will be added into the
889897 serialized kernel call.
890- ** metaparams: Additional keyword arguments that will be provided to a `grid`
898+ metaparams: A dictionary of arguments that will be provided to a `grid`
891899 (if it is a function) and to the Triton kernel as `constexpr` arguments.
892900
893901 Returns:
@@ -934,6 +942,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
934942 zeroed_outputs = zeroed_outputs ,
935943 debug = debug ,
936944 serialized_metadata = serialized_metadata ,
937- ** metaparams ,
945+ metaparams = tuple ( metaparams . items ()) ,
938946 )
939947 return tree_util .tree_unflatten (out_tree , out_flat )
0 commit comments