
Commit 195e416

Merge pull request #378 from Arech8:pr0_improve_inout

PiperOrigin-RevId: 881037610
2 parents 7fea837 + b26e6d8, commit 195e416

3 files changed: 230 additions & 41 deletions


README.md

Lines changed: 60 additions & 16 deletions
````diff
@@ -22,13 +22,13 @@ import triton.language as tl
 
 @triton.jit
 def add_kernel(
-    x_ptr,
-    y_ptr,
-    length,
-    output_ptr,
-    block_size: tl.constexpr,
+    x_ptr,  # The first three arguments
+    y_ptr,  # are input
+    length,  # arguments.
+    output_ptr,  # The implicit output argument goes after the inputs.
+    block_size: tl.constexpr,  # Constexpr params go last.
 ):
-  """Adds two vectors."""
+  """Adds two vectors: output = x + y."""
   pid = tl.program_id(axis=0)
   block_start = pid * block_size
   offsets = block_start + tl.arange(0, block_size)
@@ -47,43 +47,87 @@ import jax.numpy as jnp
 import jax_triton as jt
 
 def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
-  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
   block_size = 8
   return jt.triton_call(
-      x,
-      y,
-      x.size,
+      x,  # The kernel's input arguments come first
+      y,  # in jt.triton_call(); the output argument
+      x.size,  # is passed implicitly.
       kernel=add_kernel,
-      out_shape=out_shape,
+      out_shape=x,
       grid=(x.size // block_size,),
-      block_size=block_size)
+      block_size=block_size  # Constexpr params are passed as kwargs.
+  )
 
 x_val = jnp.arange(8)
 y_val = jnp.arange(8, 16)
 print(add(x_val, y_val))
 print(jax.jit(add)(x_val, y_val))
 ```
 
+You can also use input-output (aliased) parameters in kernels:
+
+```python
+@triton.jit
+def add_inplace_y_kernel(
+    x_ptr,  # input vector
+    y_inout_ptr,  # explicit in-out vector (could appear anywhere in the list)
+    length,
+    block_size: tl.constexpr,
+):
+  """Adds two vectors in place: y = x + y."""
+  pid = tl.program_id(axis=0)
+  block_start = pid * block_size
+  offsets = block_start + tl.arange(0, block_size)
+  mask = offsets < length
+  x = tl.load(x_ptr + offsets, mask=mask)
+  y = tl.load(y_inout_ptr + offsets, mask=mask)
+  output = x + y
+  tl.store(y_inout_ptr + offsets, output, mask=mask)
+
+
+from functools import partial
+
+# Jitting (with or without donation) isn't mandatory, but it makes invocation
+# more efficient. Without donation, XLA has to copy each non-donated in-out
+# argument before calling the kernel, since JAX arrays are immutable by default.
+@partial(jax.jit, donate_argnames="y")
+def add_inplace_y(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
+  block_size = 8
+  return jt.triton_call(
+      x,
+      y,  # explicit in-out argument
+      x.size,
+      kernel=add_inplace_y_kernel,
+      input_output_aliases={1: 0},  # input arg idx 1 (y) is aliased to output idx 0
+      out_shape=x,
+      grid=(x.size // block_size,),
+      block_size=block_size)
+
+x_val = jnp.arange(8)
+y_val = jnp.arange(8, 16)
+print(add_inplace_y(x_val, y_val))
+```
+
 See [the examples
 directory](https://github.com/jax-ml/jax-triton/tree/main/examples), especially
 [fused_attention.py](https://github.com/jax-ml/jax-triton/blob/main/examples/fused_attention.py)
 and [the fused attention
 ipynb](https://github.com/jax-ml/jax-triton/blob/main/examples/JAX_%2B_Triton_Flash_Attention.ipynb).
 
+Some other use cases are covered in the [tests](https://github.com/jax-ml/jax-triton/tree/main/tests).
+
 ## Installation
 
 ```bash
 $ pip install jax-triton
 ```
 
-Make sure you have a CUDA-compatible `jax` installed. For example you could run:
+Make sure you have a CUDA- or ROCm-compatible `jax` installed. For example you could run:
 ```bash
 $ pip install "jax[cuda12]"
 ```
 
-`jax-triton` currently requires building the latest version of `triton`
-[from source](https://triton-lang.org/main/getting-started/installation.html#from-source).
-
 ## Development
 
 To develop `jax-triton`, you can clone the repo with:
````
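A quick sanity check of the new in-out example above (a sketch; it assumes `jax` and the README's `add_inplace_y` are defined exactly as shown):

```python
import jax.numpy as jnp

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add_inplace_y(x_val, y_val))  # [ 8 10 12 14 16 18 20 22]

# Because y is donated via donate_argnames="y", its buffer may be reused for
# the result, so y_val must not be read again after this call.
```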

jax_triton/triton_lib.py

Lines changed: 36 additions & 7 deletions
````diff
@@ -34,7 +34,6 @@
 from jax._src import state
 from jax._src import util
 from jax._src.lib import gpu_triton as triton_kernel_call_lib
-import jax.dlpack
 import jax.extend as jex
 from jax.interpreters import ad
 from jax.interpreters import batching
````
````diff
@@ -435,6 +434,19 @@ def get_or_create_triton_kernel(
   kernel = _COMPILED_KERNEL_CACHE.get(cache_key)
 
   if kernel is None:
+    # First, check that the kernel signature and the reconstructed signature
+    # have the same number of parameters. A mismatch can occur due to
+    # differences in `triton_call(input_output_aliases=)` handling between
+    # jax-triton versions.
+    if len(fn.signature.parameters) != len(signature):
+      raise TypeError(
+          f"Number of parameters in the kernel '{fn}' signature "
+          f"({len(fn.signature.parameters)}: {fn.signature}) does not match "
+          f"the reconstructed signature ({len(signature)}: {signature}). "
+          "If the kernel worked on an older version of jax-triton and its "
+          "triton_call() launcher uses the `input_output_aliases` argument, "
+          "note that implicit output arguments are no longer required for "
+          "aliased args."
+      )
+
     opts = {
         "num_warps": num_warps,
         "num_stages": num_stages,
````
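To see what the new check catches, here is a hypothetical sketch of the parameter-count arithmetic (the helper below is illustrative, not part of the commit):

```python
def expected_param_count(n_inputs, n_outputs, input_output_aliases, n_constexpr):
  """Parameters a kernel must declare under the new convention (a sketch)."""
  # Inputs and constexpr params always appear in the kernel signature; only
  # outputs NOT aliased to an input get an implicit trailing pointer.
  n_fresh_outputs = n_outputs - len(set(input_output_aliases.values()))
  return n_inputs + n_fresh_outputs + n_constexpr

# add_inplace_y_kernel from the README: 3 inputs (x_ptr, y_inout_ptr, length),
# 1 output aliased to input 1, and 1 constexpr (block_size) -> 4 parameters.
assert expected_param_count(3, 1, {1: 0}, 1) == 4
# Under the old convention the aliased output still got its own pointer, so an
# old-style 5-parameter kernel now trips the TypeError above.
```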
````diff
@@ -543,8 +555,17 @@ def triton_kernel_call_lowering(
   for idx, dtype, v in scalar_args:
     args.insert(idx, v)
     arg_dtypes.insert(idx, dtype)
-  args.extend(ctx.avals_out)
-  arg_dtypes.extend(map(get_triton_type, ctx.avals_out))
+  # Extract only the output avals not referenced in the input_output_aliases mapping.
+  assert isinstance(input_output_aliases, tuple)
+  input_output_aliases = dict(input_output_aliases)
+  strictly_out_avals = [
+      aval
+      for i, aval in enumerate(ctx.avals_out)
+      if i not in input_output_aliases.values()
+  ]
+  args.extend(strictly_out_avals)
+  arg_dtypes.extend(map(get_triton_type, strictly_out_avals))
+
   named_args = dict(unsafe_zip(fn.arg_names, args))
 
   if isinstance(fn, autotuner.Autotuner):
````
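The filtering can be exercised in isolation with toy stand-ins for the avals (a sketch, not part of the commit):

```python
# input_output_aliases maps input arg index -> output index, as in the README.
input_output_aliases = dict(((1, 0),))  # built from a tuple, as in the lowering
avals_out = ["out0 (aliased to input 1)", "out1 (fresh)"]

strictly_out_avals = [
    aval
    for i, aval in enumerate(avals_out)
    if i not in input_output_aliases.values()
]
assert strictly_out_avals == ["out1 (fresh)"]
# Only the fresh output is appended to args, so only fresh outputs become
# implicit kernel pointer parameters.
```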
````diff
@@ -606,6 +627,10 @@ def prune_configs(configs, named_args, **kwargs):
         "`kernel` must be a Triton `JITFunction`, `Heuristics` or `Autotuner`."
     )
 
+  output2input = {v: k for k, v in input_output_aliases.items()}
+  if len(output2input) != len(input_output_aliases):
+    raise ValueError("input_output_aliases must be a bijection")
+
   outputs_offset = len(ctx.avals_in) + len(scalar_args)
   config_params = []
   for config in configs:
````
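The bijection requirement exists because the mapping is inverted: if two inputs alias the same output, the inverted dict silently collapses. A toy illustration:

```python
# Valid: each output is claimed by at most one input.
ok = {1: 0}
assert len({v: k for k, v in ok.items()}) == len(ok)

# Invalid: inputs 0 and 1 both claim output 0, so inverting loses an entry;
# the check above turns this silent collapse into a ValueError.
bad = {0: 0, 1: 0}
assert len({v: k for k, v in bad.items()}) != len(bad)
```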
````diff
@@ -616,9 +641,13 @@ def prune_configs(configs, named_args, **kwargs):
     if callable(zeroed_outputs):
       config_zeroed_outputs = config_zeroed_outputs(config_metaparams)
 
+    # zeroed_params_with_sizes maps output_arg_idx -> aval_size_bytes;
+    # config_zeroed_outputs contains output ordinal numbers.
     zeroed_params_with_sizes = {
-        i + outputs_offset: aval_size_bytes(ctx.avals_out[i])
-        for i in sorted(config_zeroed_outputs)
+        output2input[i] if i in output2input else i + outputs_offset: aval_size_bytes(
+            ctx.avals_out[i]
+        )
+        for i in sorted(config_zeroed_outputs)
     }
 
     config_params.append(
````
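The key computation above, with toy numbers (a sketch; the size stands in for an `aval_size_bytes` result): an aliased output is zeroed through its input argument slot rather than through an implicit output slot.

```python
output2input = {0: 1}  # output 0 is backed by input arg 1, as in the README
outputs_offset = 3     # len(ctx.avals_in) + len(scalar_args) in the real code
sizes = [16]           # stand-in for aval_size_bytes(ctx.avals_out[0])
config_zeroed_outputs = {0}

zeroed_params_with_sizes = {
    output2input[i] if i in output2input else i + outputs_offset: sizes[i]
    for i in sorted(config_zeroed_outputs)
}
assert zeroed_params_with_sizes == {1: 16}  # zeroed via the input slot
```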
````diff
@@ -688,7 +717,7 @@ def prune_configs(configs, named_args, **kwargs):
     named_scalar_args = {fn.arg_names[i]: v for i, _, v in scalar_args}
     input_output_aliases_with_sizes = tuple(
         (input_idx, output_idx, aval_size_bytes(ctx.avals_in[input_idx]))
-        for input_idx, output_idx in input_output_aliases
+        for input_idx, output_idx in input_output_aliases.items()
     )
     kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall(
         f"{kernel_call_name} ({fn.fn.__name__}) {named_scalar_args}",
````
````diff
@@ -703,7 +732,7 @@ def prune_configs(configs, named_args, **kwargs):
       custom_call_target_name,
       api_version=2,
       backend_config=zlib.compress(call_proto),
-      operand_output_aliases=dict(input_output_aliases),
+      operand_output_aliases=input_output_aliases,
   )
   return rule(ctx, *array_args)
````
