Skip to content

Commit 08949bf

Browse files
Arech8 authored and Google-ML-Automation committed
Integrating changes from #379
FUTURE_COPYBARA_INTEGRATE_REVIEW=#379 from Arech8:pr1_enable_gluon 48a0ff7 PiperOrigin-RevId: 883218017
1 parent 195e416 commit 08949bf

4 files changed

Lines changed: 243 additions & 36 deletions

File tree

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
![PyPI version](https://img.shields.io/pypi/v/jax-triton)
44

5-
The `jax-triton` repository contains integrations between [JAX](https://github.com/jax-ml/jax) and [Triton](https://github.com/openai/triton).
5+
The `jax-triton` repository contains integrations between [JAX](https://github.com/jax-ml/jax)
6+
and [Triton](https://github.com/openai/triton), including support for the Gluon dialect.
67

78
Documentation can be found [here](https://jax-ml.github.io/jax-triton).
89

@@ -26,7 +27,7 @@ def add_kernel(
2627
y_ptr, # are input
2728
length, # arguments.
2829
output_ptr, # Implicit output argument goes after inputs.
29-
block_size: tl.constexpr, # Constexpr params goes the last.
30+
block_size: tl.constexpr, # Constexpr params go last.
3031
):
3132
"""Adds two vectors output = x + y."""
3233
pid = tl.program_id(axis=0)

jax_triton/triton_lib.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Module for calling Triton kernels from JAX."""
15+
"""Module for calling Triton or Triton.Gluon kernels from JAX."""
1616

1717
from __future__ import annotations
1818

@@ -46,13 +46,15 @@
4646
CAN_USE_TRITON = False
4747
try:
4848
import triton
49-
from triton.compiler import code_generator as code_gen
5049
from triton.compiler import compiler as tc
5150
import triton.language as tl
5251
from triton.runtime import autotuner
5352
import triton._C.libtriton as _triton
5453
import triton.backends.nvidia.compiler as cb
5554

55+
import triton.experimental.gluon._runtime as gl_runtime
56+
from triton.experimental.gluon import language as gl
57+
5658
CAN_USE_TRITON = True
5759
except ModuleNotFoundError:
5860
pass
@@ -115,7 +117,7 @@ def avals_to_layouts(avals):
115117
def get_triton_type(obj: Any) -> str:
116118
if isinstance(obj, (jax.core.ShapedArray, state.AbstractRef)):
117119
return f"*{_JAX_TO_TRITON_TYPE_MAP[obj.dtype]}"
118-
if isinstance(obj, tl.constexpr):
120+
if isinstance(obj, (tl.constexpr, gl.constexpr)):
119121
obj = obj.value
120122
if isinstance(obj, bool): # True == isinstance(True, int) !!!
121123
return "B"
@@ -160,10 +162,8 @@ def aval_size_bytes(aval):
160162
return np.dtype(aval.dtype).itemsize * aval.size
161163

162164

163-
def get_cuda_backend(device, compute_capability):
164-
target = cb.GPUTarget("cuda", compute_capability, 32)
165-
backend = cb.CUDABackend(target)
166-
return backend
165+
def make_gpu_target_cuda(device, compute_capability):
166+
return cb.GPUTarget("cuda", compute_capability, 32)
167167

168168

169169
_IS_HIPBackend_PATCHED = False
@@ -199,15 +199,13 @@ def fixed_is_within_2gb(arg):
199199
hb.HIPBackend.is_within_2gb = fixed_is_within_2gb
200200

201201

202-
def get_hip_backend(device, compute_capability):
202+
def make_gpu_target_hip(device, compute_capability):
203203
# TODO(Arech): remove _patch_hip_backend() once Triton releases a fix
204204
_patch_hip_backend()
205205

206206
arch = triton_kernel_call_lib.get_arch_details(device)
207207
arch = arch.split(":")[0]
208-
target = hb.GPUTarget("hip", arch, 64)
209-
backend = hb.HIPBackend(target)
210-
return backend
208+
return hb.GPUTarget("hip", arch, 64)
211209

212210

213211
@dataclasses.dataclass
@@ -358,7 +356,7 @@ def compile_ttir_to_hsaco_inplace(
358356

359357

360358
def get_or_create_triton_kernel(
361-
backend_init_func,
359+
make_gpu_target_func,
362360
platform,
363361
fn,
364362
arg_dtypes,
@@ -385,7 +383,8 @@ def get_or_create_triton_kernel(
385383
if num_ctas > 1 and compute_capability < 90:
386384
raise ValueError("num_ctas > 1 unsupported before Hopper.")
387385

388-
backend = backend_init_func(device, compute_capability)
386+
gpu_target = make_gpu_target_func(device, compute_capability)
387+
backend = triton.compiler.make_backend(gpu_target)
389388

390389
signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
391390
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
@@ -470,16 +469,15 @@ def get_or_create_triton_kernel(
470469
backend.load_dialects(context)
471470
codegen_fns = backend.get_codegen_implementation(options)
472471

473-
module = code_gen.ast_to_ttir(
474-
fn,
475-
tc.ASTSource(
476-
fn, constexprs=constants, signature=signature, attrs=attrs
477-
),
478-
options=options,
479-
codegen_fns=codegen_fns,
480-
context=context,
481-
module_map=backend.get_module_map(),
472+
real_ASTSource = (
473+
gl_runtime.GluonASTSource
474+
if isinstance(fn, gl_runtime.GluonJITFunction)
475+
else tc.ASTSource
482476
)
477+
module = real_ASTSource(
478+
fn, constexprs=constants, signature=signature, attrs=attrs
479+
).make_ir(gpu_target, options, codegen_fns, backend.get_module_map(), context)
480+
483481
ttir = str(module)
484482

485483
compilation_result = compile_ttir_inplace(
@@ -529,7 +527,7 @@ def get_or_create_triton_kernel(
529527

530528

531529
def triton_kernel_call_lowering(
532-
backend_init_func,
530+
make_gpu_target_func,
533531
ctx,
534532
*array_args,
535533
fn,
@@ -547,16 +545,21 @@ def triton_kernel_call_lowering(
547545
zeroed_outputs,
548546
debug,
549547
serialized_metadata,
550-
**metaparams,
548+
metaparams: tuple[tuple[str, Any], ...],
551549
):
550+
# we have to pass the metaparams dictionary as a tuple to allow the hashing necessary for
551+
# lowering via xla_primitive_callable()
552+
assert isinstance(metaparams, tuple), "metaparams must be tuple[tuple[str, Any], ...]"
553+
metaparams = dict(metaparams) # will crash if tuple format is incompatible
554+
552555
kernel_call_name = name
553556
args = list(ctx.avals_in)
554557
arg_dtypes = list(map(get_triton_type, ctx.avals_in))
555558
for idx, dtype, v in scalar_args:
556559
args.insert(idx, v)
557560
arg_dtypes.insert(idx, dtype)
558561
# Extract only the output avals not referenced in the input_output_aliases mapping.
559-
assert isinstance(input_output_aliases, tuple)
562+
assert isinstance(input_output_aliases, tuple), "input_output_aliases must be a tuple"
560563
input_output_aliases = dict(input_output_aliases)
561564
strictly_out_avals = [
562565
aval
@@ -622,9 +625,9 @@ def prune_configs(configs, named_args, **kwargs):
622625
configs = updated_configs
623626
fn = fn.fn
624627

625-
if not isinstance(fn, triton.JITFunction):
628+
if not isinstance(fn, (triton.JITFunction, gl_runtime.GluonJITFunction)):
626629
raise ValueError(
627-
"`kernel` must be a Triton `JITFunction`, `Heuristics` or `Autotuner`."
630+
"`kernel` must be a Triton `JITFunction`, `GluonJITFunction`, `Heuristics` or `Autotuner`."
628631
)
629632

630633
output2input = {v: k for k, v in input_output_aliases.items()}
@@ -664,7 +667,7 @@ def prune_configs(configs, named_args, **kwargs):
664667
kernel_calls = []
665668
for params in config_params:
666669
kernel, specialization_attr = get_or_create_triton_kernel(
667-
backend_init_func,
670+
make_gpu_target_func,
668671
ctx.module_context.platforms[0],
669672
fn,
670673
arg_dtypes,
@@ -739,13 +742,13 @@ def prune_configs(configs, named_args, **kwargs):
739742

740743
mlir.register_lowering(
741744
triton_kernel_call_p,
742-
functools.partial(triton_kernel_call_lowering, get_cuda_backend),
745+
functools.partial(triton_kernel_call_lowering, make_gpu_target_cuda),
743746
platform="cuda",
744747
)
745748

746749
mlir.register_lowering(
747750
triton_kernel_call_p,
748-
functools.partial(triton_kernel_call_lowering, get_hip_backend),
751+
functools.partial(triton_kernel_call_lowering, make_gpu_target_hip),
749752
platform="rocm",
750753
)
751754

@@ -791,6 +794,7 @@ def triton_call(
791794
*args: jax.Array | bool | int | float | np.float32,
792795
kernel: (
793796
triton.JITFunction
797+
| gl_runtime.GluonJITFunction
794798
| triton.runtime.Heuristics
795799
| triton.runtime.Autotuner
796800
),
@@ -865,7 +869,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
865869
Args:
866870
*args: Inputs for the Triton kernel.
867871
kernel: A Triton kernel (e.g. a function decorated with `triton.jit`). All
868-
static values should be annotated with `triton.language.constexpr`.
872+
static values should be annotated with `triton.language.constexpr` or
873+
`triton.experimental.gluon.language.constexpr`.
869874
out_shape: A `jax.ShapeDtypeStruct` (or something that has `.shape` and
870875
`.dtype` attributes) or a sequence thereof that specify the output(s) of
871876
the kernel. Pointers for each of the `jax.ShapeDtypeStruct`s in
@@ -880,14 +885,17 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
880885
indices. Providing a mapping will alias the corresponding buffers.
881886
zeroed_outputs: A sequence of indices, or a function returning a sequence of
882887
indices, for outputs that should be zeroed before the kernel is launched.
888+
Note that this also supports zeroing input-output (i.e. aliased through
889+
`input_output_aliases`) arguments that should be treated as outputs in this
890+
argument.
883891
num_warps: The number of warps used to execute the Triton kernel.
884892
num_stages: The number of stages emitted by the Triton compiler.
885893
num_ctas: The size of thread blocks per cluster to be used on GPUs with
886894
compute capabilities >= 9.0. It must be less or equal to 8.
887895
debug: Prints out intermediate IRs if True for debugging purposes.
888896
serialized_metadata: Arbitrary metadata that will be added into the
889897
serialized kernel call.
890-
**metaparams: Additional keyword arguments that will be provided to a `grid`
898+
metaparams: A dictionary of arguments that will be provided to a `grid`
891899
(if it is a function) and to the Triton kernel as `constexpr` arguments.
892900
893901
Returns:
@@ -934,6 +942,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
934942
zeroed_outputs=zeroed_outputs,
935943
debug=debug,
936944
serialized_metadata=serialized_metadata,
937-
**metaparams,
945+
metaparams=tuple(metaparams.items()),
938946
)
939947
return tree_util.tree_unflatten(out_tree, out_flat)

jax_triton/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version_info__ = (0, 3, 1)
15+
__version_info__ = (0, 4, 0)
1616
__version__ = ".".join(str(v) for v in __version_info__)

0 commit comments

Comments
 (0)