Skip to content

Commit d530f07

Browse files
committed
[FRONTEND] Implement ragged TMA descriptor functionality in Gluon
1 parent 80c6222 commit d530f07

3 files changed

Lines changed: 364 additions & 2 deletions

File tree

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import torch
2+
import pytest
3+
import triton
4+
import triton.language as tl
5+
from typing import Optional
6+
7+
from triton._internal_testing import is_hopper_or_newer
8+
from triton.tools.ragged_tma import (
9+
create_ragged_descriptor,
10+
create_ragged_descriptor_device_2d,
11+
create_ragged_descriptor_device_3d,
12+
load_ragged,
13+
store_ragged,
14+
)
15+
16+
17+
@triton.jit
def example_load_store_kernel_host_desc(
    x_desc, y_desc, x_off, y_off, num_slices, ragged_dim: tl.constexpr, ndim: tl.constexpr
):
    """Copy one ragged slice between two host-built ragged TMA descriptors.

    Loads num_slices rows starting at x_off along ragged_dim from x_desc and
    stores them at y_off in y_desc; ndim (constexpr) selects 2D vs. 3D block
    coordinates.
    """
    if ndim == 2:
        tile = load_ragged(x_desc, x_off, num_slices, [0, 0], ragged_dim)
        store_ragged(y_desc, y_off, num_slices, [0, 0], tile, ragged_dim)
    else:
        tile = load_ragged(x_desc, x_off, num_slices, [0, 0, 0], ragged_dim)
        store_ragged(y_desc, y_off, num_slices, [0, 0, 0], tile, ragged_dim)
27+
28+
29+
@triton.jit
def example_load_store_kernel_device_desc_2d(
    x_ptr, y_ptr,
    x_off, y_off, num_slices,
    shape_0, shape_1,
    stride_0, stride_1,
    block_shape_0: tl.constexpr, block_shape_1: tl.constexpr,
    ragged_dim: tl.constexpr,
):
    """Round-trip a ragged 2D slice using descriptors built on the device.

    Builds one ragged TMA descriptor over x_ptr (source) and one over y_ptr
    (destination) with identical geometry, then copies num_slices rows from
    offset x_off to offset y_off along ragged_dim.
    """
    src_desc = create_ragged_descriptor_device_2d(
        x_ptr,
        shape_0, shape_1,
        stride_0, stride_1,
        block_shape_0, block_shape_1,
        ragged_dim,
    )
    dst_desc = create_ragged_descriptor_device_2d(
        y_ptr,
        shape_0, shape_1,
        stride_0, stride_1,
        block_shape_0, block_shape_1,
        ragged_dim,
    )

    tile = load_ragged(src_desc, x_off, num_slices, [0, 0], ragged_dim)
    store_ragged(dst_desc, y_off, num_slices, [0, 0], tile, ragged_dim)
55+
56+
57+
@triton.jit
def example_load_store_kernel_device_desc_3d(
    x_ptr, y_ptr,
    x_off, y_off, num_slices,
    shape_0, shape_1, shape_2,
    stride_0, stride_1, stride_2,
    block_shape_0: tl.constexpr, block_shape_1: tl.constexpr, block_shape_2: tl.constexpr,
    ragged_dim: tl.constexpr,
):
    """Round-trip a ragged 3D slice using descriptors built on the device.

    Same structure as the 2D variant: construct matching source/destination
    ragged descriptors over x_ptr/y_ptr, then copy num_slices rows from
    offset x_off to offset y_off along ragged_dim.
    """
    src_desc = create_ragged_descriptor_device_3d(
        x_ptr,
        shape_0, shape_1, shape_2,
        stride_0, stride_1, stride_2,
        block_shape_0, block_shape_1, block_shape_2,
        ragged_dim,
    )
    dst_desc = create_ragged_descriptor_device_3d(
        y_ptr,
        shape_0, shape_1, shape_2,
        stride_0, stride_1, stride_2,
        block_shape_0, block_shape_1, block_shape_2,
        ragged_dim,
    )

    tile = load_ragged(src_desc, x_off, num_slices, [0, 0, 0], ragged_dim)
    store_ragged(dst_desc, y_off, num_slices, [0, 0, 0], tile, ragged_dim)
83+
84+
85+
def _generate_test_params():
86+
dtypes = ["float16", "float32"]
87+
modes = ["host", "device"]
88+
89+
params = []
90+
for dtype in dtypes:
91+
for mode in modes:
92+
# 2D tensors: only ragged_dim=0 is valid
93+
params.append((dtype, mode, 2, 0))
94+
# 3D tensors: ragged_dim=0 and ragged_dim=1 are valid
95+
params.append((dtype, mode, 3, 0))
96+
params.append((dtype, mode, 3, 1))
97+
98+
return params
99+
100+
101+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
@pytest.mark.parametrize(
    "dtype_name,descriptor_mode,ndim,ragged_dim", _generate_test_params()
)
def test_ragged_tma(dtype_name, descriptor_mode, ndim, ragged_dim):
    """Copy a ragged slice from an all-ones tensor into an all-zeros tensor
    and verify that exactly the targeted region was written.

    descriptor_mode selects whether the ragged TMA descriptors are built on
    the host (create_ragged_descriptor) or on the device inside the kernel.
    """
    torch_dtype = getattr(torch, dtype_name)

    # Fixed geometries per (ndim, ragged_dim).  Note the 2D case deliberately
    # uses a block (32x128) wider than the tensor's second dim (80), so the
    # copy is bounds-limited along the non-ragged dim as well.
    if ndim == 2:
        shape = [128, 80]
        strides = [80, 1]
        block_shape = [32, 128]
    else:  # ndim == 3
        if ragged_dim == 0:
            shape = [64, 32, 32]
            strides = [32 * 32, 32, 1]
            block_shape = [16, 16, 32]
        else:  # ragged_dim == 1
            shape = [64, 32, 32]
            strides = [32 * 32, 32, 1]
            block_shape = [32, 16, 32]

    src = torch.ones(shape, dtype=torch_dtype, device="cuda")
    dst = torch.zeros(shape, dtype=torch_dtype, device="cuda")

    # A slice strictly smaller than the ragged block dim, placed mid-tensor
    # on the destination side so both under- and overrun would be visible.
    num_slices = min(block_shape[ragged_dim] - 1, shape[ragged_dim] // 3)
    x_off = 0
    y_off = (shape[ragged_dim] - num_slices) // 2

    # Device-side tma.make_tensor_descriptor needs a scratch allocator.
    def alloc_fn(size: int, align: int, stream: Optional[int]):
        return torch.empty(size, dtype=torch.int8, device="cuda")

    triton.set_allocator(alloc_fn)

    if descriptor_mode == "host":
        x_desc = create_ragged_descriptor(src, block_shape, ragged_dim)
        y_desc = create_ragged_descriptor(dst, block_shape, ragged_dim)

        example_load_store_kernel_host_desc[(1,)](
            x_desc,
            y_desc,
            x_off,
            y_off,
            num_slices,
            ragged_dim,
            ndim,
        )
    else:
        if ndim == 2:
            example_load_store_kernel_device_desc_2d[(1,)](
                src,
                dst,
                x_off,
                y_off,
                num_slices,
                shape[0], shape[1],
                strides[0], strides[1],
                block_shape[0], block_shape[1],
                ragged_dim,
            )
        else:  # ndim == 3
            example_load_store_kernel_device_desc_3d[(1,)](
                src,
                dst,
                x_off,
                y_off,
                num_slices,
                shape[0], shape[1], shape[2],
                strides[0], strides[1], strides[2],
                block_shape[0], block_shape[1], block_shape[2],
                ragged_dim,
            )

    # Partition the written window of dst into the region before the slice,
    # the copied slice itself, and the region after it; the non-ragged dims
    # are clipped to block_shape since only one block was transferred.
    if ragged_dim == 0:
        if ndim == 2:
            before = dst[:y_off, : block_shape[1]]
            copied = dst[y_off : y_off + num_slices, : block_shape[1]]
            after = dst[y_off + num_slices :, : block_shape[1]]
        else:  # ndim == 3
            before = dst[:y_off, : block_shape[1], : block_shape[2]]
            copied = dst[y_off : y_off + num_slices, : block_shape[1], : block_shape[2]]
            after = dst[y_off + num_slices :, : block_shape[1], : block_shape[2]]
    else:  # ragged_dim == 1
        before = dst[: block_shape[0], :y_off, : block_shape[2]]
        copied = dst[: block_shape[0], y_off : y_off + num_slices, : block_shape[2]]
        after = dst[: block_shape[0], y_off + num_slices :, : block_shape[2]]

    # Only the copied region should hold ones; everything else stays zero.
    res0 = torch.all(before == 0.0).item()
    res1 = torch.all(copied == 1.0).item()
    res2 = torch.all(after == 0.0).item()

    assert [res0, res1, res2] == [
        True,
        True,
        True,
    ], f"Failed for {ndim}D {descriptor_mode} mode ragged_dim={ragged_dim}: before={res0}, copied={res1}, after={res2}"
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from triton.experimental import gluon
2+
from triton.experimental.gluon import language as ttgl
3+
from triton.experimental.gluon.language._standard import _import_from_triton
4+
from triton.experimental.gluon.language.nvidia.hopper import tma
5+
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
6+
7+
import triton.tools.ragged_tma as tl_ragged
8+
9+
# fmt: off
10+
11+
def create_ragged_descriptor_host(T, block_shape, layout, ragged_dim=0):
    """Build a Gluon TensorDescriptor giving ragged TMA access to tensor T.

    Delegates the ragged shape/stride computation to the Triton helper, then
    rewraps the resulting descriptor fields together with the given layout.
    """
    desc = tl_ragged.create_ragged_descriptor(T, block_shape, ragged_dim)
    return TensorDescriptor(desc.base, desc.shape, desc.strides,
                            desc.block_shape, layout, padding=desc.padding)
21+
22+
23+
# Wrap the Triton-side jit helpers so they are callable from Gluon jit functions.
_compute_ragged_descriptor_params_2d = _import_from_triton(tl_ragged._compute_ragged_descriptor_params_2d)
_compute_ragged_descriptor_params_3d = _import_from_triton(tl_ragged._compute_ragged_descriptor_params_3d)
25+
26+
@gluon.jit
def create_ragged_descriptor_device_2d(
    base_ptr,
    shape_0, shape_1,
    stride_0, stride_1: ttgl.constexpr,
    block_shape_0: ttgl.constexpr, block_shape_1: ttgl.constexpr,
    layout,
    ragged_dim: ttgl.constexpr
):
    """Build, inside a Gluon kernel, the 4D TMA descriptor for a ragged 2D view.

    Mirrors triton.tools.ragged_tma.create_ragged_descriptor_device_2d, but
    goes through Gluon's tma.make_tensor_descriptor and threads the extra
    `layout` argument through to the descriptor.
    """
    # Two synthetic leading dims are prepended; see
    # _compute_ragged_descriptor_params_2d for the shape/stride encoding.
    shape, stride = _compute_ragged_descriptor_params_2d(
        shape_0, shape_1,
        stride_0, stride_1,
        ragged_dim
    )
    return tma.make_tensor_descriptor(
        base_ptr,
        shape=shape,
        # The innermost stride is passed as the constexpr stride_1 rather than
        # stride[3] (same value) — presumably make_tensor_descriptor requires
        # it compile-time known; confirm against the Gluon TMA API.
        strides=[stride[0], stride[1], stride[2], stride_1],
        # Unit blocks along the two synthetic dims; real block in the last two.
        block_shape=[1, 1, block_shape_0, block_shape_1],
        layout=layout,
    )
47+
48+
49+
@gluon.jit
def create_ragged_descriptor_device_3d(
    base_ptr,
    shape_0, shape_1, shape_2,
    stride_0, stride_1, stride_2: ttgl.constexpr,
    block_shape_0: ttgl.constexpr, block_shape_1: ttgl.constexpr, block_shape_2: ttgl.constexpr,
    layout,
    ragged_dim: ttgl.constexpr
):
    """Build, inside a Gluon kernel, the 5D TMA descriptor for a ragged 3D view.

    Mirrors triton.tools.ragged_tma.create_ragged_descriptor_device_3d, but
    goes through Gluon's tma.make_tensor_descriptor and threads the extra
    `layout` argument through to the descriptor.
    """
    # Two synthetic leading dims are prepended; see
    # _compute_ragged_descriptor_params_3d for the shape/stride encoding.
    shape, stride = _compute_ragged_descriptor_params_3d(
        shape_0, shape_1, shape_2,
        stride_0, stride_1, stride_2,
        ragged_dim
    )
    return tma.make_tensor_descriptor(
        base_ptr,
        shape=shape,
        # Innermost stride passed as the constexpr stride_2 (same value as
        # stride[4]) — presumably required compile-time known; confirm.
        strides=[stride[0], stride[1], stride[2], stride[3], stride_2],
        # Unit blocks along the two synthetic dims; real block in the last three.
        block_shape=[1, 1, block_shape_0, block_shape_1, block_shape_2],
        layout=layout,
    )
70+
71+
72+
# Gluon-callable alias of the Triton index-remapping helper (to_ragged_indices).
_to_ragged_indices = _import_from_triton(tl_ragged.to_ragged_indices)
73+
74+
75+
@gluon.jit
def to_ragged_coords(slice_off, slice_size, coords, ragged_dim: ttgl.constexpr):
    """Expand logical coords into the (ndim + 2)-dim ragged-descriptor space.

    The two synthetic leading indices (c0, c1) plus the remapped ragged index
    c2 replace coords[ragged_dim]; the remaining coords keep their order.
    See to_ragged_indices in triton.tools.ragged_tma for the (c0, c1, c2)
    encoding.
    """
    c0, c1, c2 = _to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
    return [c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:]

python/triton/tools/ragged_tma.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,104 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0):
4545
return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
4646

4747

48+
@triton.jit
def _compute_ragged_descriptor_params_2d(
    shape_0, shape_1,
    stride_0, stride_1: tl.constexpr,
    ragged_dim: tl.constexpr
):
    """Compute (shape, strides) of the 4D TMA descriptor that emulates a
    ragged view of a 2D tensor.

    Two synthetic leading dimensions are prepended to the tensor's own dims.
    With the coordinates produced by to_ragged_indices
    (c0 = 2**30, c1 = slice_off + slice_size, c2 = 2**30 - slice_size + row),
    the address contribution is
        c0*(2**34 - stride_0) + c1*stride_0 + c2*stride_0
      = (slice_off + row)*stride_0 + 2**30 * 2**34,
    and 2**30 * 2**34 == 2**64 wraps to zero in 64-bit arithmetic, leaving
    the plain row address.  Rows with row >= slice_size push c2 to or past
    the 2**30 extent of its dimension — NOTE(review): assumed that the TMA
    unit then masks those rows as out-of-bounds; confirm.
    """
    tl.static_assert(
        ragged_dim < 1,
        "Using last dim as ragged dim is not supported"
    )

    max_int: tl.constexpr = 0x7fff0000  # oversized extent for the synthetic dims
    billion: tl.constexpr = 0x40000000  # == 2**30; must match to_ragged_indices
    two_to_34 = tl.to_tensor(2**34)
    return (
        [max_int, max_int, billion, shape_1],
        [two_to_34 - stride_0, stride_0, stride_0, stride_1],
    )
66+
67+
68+
@triton.jit
def _compute_ragged_descriptor_params_3d(
    shape_0, shape_1, shape_2,
    stride_0, stride_1, stride_2: tl.constexpr,
    ragged_dim: tl.constexpr
):
    """Compute (shape, strides) of the 5D TMA descriptor that emulates a
    ragged view of a 3D tensor.

    Same 2**34/2**30 wraparound encoding as the 2D variant (see
    _compute_ragged_descriptor_params_2d), except ragged_dim selects which
    tensor stride participates in the trick: the billion-sized dimension and
    the duplicated stride move to the ragged axis, while the non-ragged dims
    keep their true extents.  The last dimension may never be ragged.
    """
    tl.static_assert(
        ragged_dim < 2,
        "Using last dim as ragged dim is not supported"
    )

    max_int: tl.constexpr = 0x7fff0000  # oversized extent for the synthetic dims
    billion: tl.constexpr = 0x40000000  # == 2**30; must match to_ragged_indices
    two_to_34 = tl.to_tensor(2**34)
    if ragged_dim == 0:
        return (
            [max_int, max_int, billion, shape_1, shape_2],
            [two_to_34 - stride_0, stride_0, stride_0, stride_1, stride_2],
        )
    else:
        # Ragged along dim 1: dim 0 precedes the billion-sized ragged dim
        # (matching the coordinate order built by to_ragged_coords).
        return (
            [max_int, max_int, shape_0, billion, shape_2],
            [two_to_34 - stride_1, stride_1, stride_0, stride_1, stride_2],
        )
92+
93+
94+
@triton.jit
def create_ragged_descriptor_device_2d(
    base_ptr,
    shape_0, shape_1,
    stride_0, stride_1: tl.constexpr,
    block_shape_0: tl.constexpr, block_shape_1: tl.constexpr,
    ragged_dim: tl.constexpr
):
    """Create, on-device, the 4D TMA descriptor for a ragged view of a 2D
    tensor at base_ptr.

    See _compute_ragged_descriptor_params_2d for the shape/stride encoding.
    """
    shape, stride = _compute_ragged_descriptor_params_2d(
        shape_0, shape_1,
        stride_0, stride_1,
        ragged_dim
    )
    # Constexpr 1 so the unit block dims are compile-time known.
    one: tl.constexpr = 1
    return tl.make_tensor_descriptor(
        base_ptr,
        shape=shape,
        # Innermost stride passed as the constexpr stride_1 (same value as
        # stride[3]) — presumably required compile-time known; confirm.
        strides=[stride[0], stride[1], stride[2], stride_1],
        # Unit blocks along the two synthetic dims; real block in the last two.
        block_shape=[one, one, block_shape_0, block_shape_1],
    )
114+
115+
116+
@triton.jit
def create_ragged_descriptor_device_3d(
    base_ptr,
    shape_0, shape_1, shape_2,
    stride_0, stride_1, stride_2: tl.constexpr,
    block_shape_0: tl.constexpr, block_shape_1: tl.constexpr, block_shape_2: tl.constexpr,
    ragged_dim: tl.constexpr
):
    """Create, on-device, the 5D TMA descriptor for a ragged view of a 3D
    tensor at base_ptr.

    See _compute_ragged_descriptor_params_3d for the shape/stride encoding.
    """
    shape, stride = _compute_ragged_descriptor_params_3d(
        shape_0, shape_1, shape_2,
        stride_0, stride_1, stride_2,
        ragged_dim
    )
    # Constexpr 1 so the unit block dims are compile-time known.
    one: tl.constexpr = 1
    return tl.make_tensor_descriptor(
        base_ptr,
        shape=shape,
        # Innermost stride passed as the constexpr stride_2 (same value as
        # stride[4]) — presumably required compile-time known; confirm.
        strides=[stride[0], stride[1], stride[2], stride[3], stride_2],
        # Unit blocks along the two synthetic dims; real block in the last three.
        block_shape=[one, one, block_shape_0, block_shape_1, block_shape_2],
    )
136+
137+
48138
@triton.jit
def to_ragged_indices(slice_off, slice_size, row):
    """
    Helper function for load_ragged and store_ragged.

    Maps `row` within a ragged slice [slice_off, slice_off + slice_size) to
    the three coordinates (c0, c1, c2) of the padded ragged-TMA descriptor:
        c0 = 2**30 (fixed), c1 = slice_off + slice_size,
        c2 = 2**30 - slice_size + row.
    Combined with the descriptor strides from
    _compute_ragged_descriptor_params_*, these cancel to the plain address
    (slice_off + row) * stride via 64-bit wraparound, while rows at or past
    the slice end land outside the 2**30 extent of c2's dimension.
    """
    billion = 0x40000000  # == 2**30
    x = billion - slice_size + row
    y = slice_off + slice_size
    return billion, y, x
59147

60148

0 commit comments

Comments
 (0)