|
| 1 | +import torch |
| 2 | +import pytest |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +from triton._internal_testing import is_hopper_or_newer |
| 8 | +from triton.tools.ragged_tma import ( |
| 9 | + create_ragged_descriptor, |
| 10 | + create_ragged_descriptor_device_2d, |
| 11 | + create_ragged_descriptor_device_3d, |
| 12 | + load_ragged, |
| 13 | + store_ragged, |
| 14 | +) |
| 15 | + |
| 16 | + |
@triton.jit
def example_load_store_kernel_host_desc(
    x_desc, y_desc, x_off, y_off, num_slices, ragged_dim: tl.constexpr, ndim: tl.constexpr
):
    """Copy `num_slices` slices along `ragged_dim` from x_desc to y_desc.

    Both descriptors are ragged TMA descriptors created on the host
    (see `create_ragged_descriptor`).  `x_off`/`y_off` are the starting
    slice offsets along the ragged dimension in the source/destination.
    """
    # `ndim` is constexpr, so this branch is resolved at compile time; the
    # coordinate list passed to load/store must match the descriptor rank.
    if ndim == 2:
        data = load_ragged(x_desc, x_off, num_slices, [0, 0], ragged_dim)
        store_ragged(y_desc, y_off, num_slices, [0, 0], data, ragged_dim)
    else:
        data = load_ragged(x_desc, x_off, num_slices, [0, 0, 0], ragged_dim)
        store_ragged(y_desc, y_off, num_slices, [0, 0, 0], data, ragged_dim)
| 27 | + |
| 28 | + |
@triton.jit
def example_load_store_kernel_device_desc_2d(
    x_ptr, y_ptr,
    x_off, y_off, num_slices,
    shape_0, shape_1,
    stride_0, stride_1,
    block_shape_0: tl.constexpr, block_shape_1: tl.constexpr,
    ragged_dim: tl.constexpr,
):
    """Device-descriptor variant of the 2D ragged copy.

    Unlike the host-descriptor kernel, the ragged TMA descriptors are built
    inside the kernel from raw pointers plus shape/stride/block parameters
    (both tensors share the same geometry here).
    """
    x_desc = create_ragged_descriptor_device_2d(
        x_ptr,
        shape_0, shape_1,
        stride_0, stride_1,
        block_shape_0, block_shape_1,
        ragged_dim,
    )
    y_desc = create_ragged_descriptor_device_2d(
        y_ptr,
        shape_0, shape_1,
        stride_0, stride_1,
        block_shape_0, block_shape_1,
        ragged_dim,
    )

    # Copy `num_slices` slices along `ragged_dim`, from offset x_off in the
    # source to offset y_off in the destination; [0, 0] are the block coords.
    data = load_ragged(x_desc, x_off, num_slices, [0, 0], ragged_dim)
    store_ragged(y_desc, y_off, num_slices, [0, 0], data, ragged_dim)
| 55 | + |
| 56 | + |
@triton.jit
def example_load_store_kernel_device_desc_3d(
    x_ptr, y_ptr,
    x_off, y_off, num_slices,
    shape_0, shape_1, shape_2,
    stride_0, stride_1, stride_2,
    block_shape_0: tl.constexpr, block_shape_1: tl.constexpr, block_shape_2: tl.constexpr,
    ragged_dim: tl.constexpr,
):
    """Device-descriptor variant of the 3D ragged copy.

    Mirrors `example_load_store_kernel_device_desc_2d` but with rank-3
    geometry; both tensors share the same shape/stride/block parameters.
    """
    x_desc = create_ragged_descriptor_device_3d(
        x_ptr,
        shape_0, shape_1, shape_2,
        stride_0, stride_1, stride_2,
        block_shape_0, block_shape_1, block_shape_2,
        ragged_dim,
    )
    y_desc = create_ragged_descriptor_device_3d(
        y_ptr,
        shape_0, shape_1, shape_2,
        stride_0, stride_1, stride_2,
        block_shape_0, block_shape_1, block_shape_2,
        ragged_dim,
    )

    # Copy `num_slices` slices along `ragged_dim`, from offset x_off in the
    # source to offset y_off in the destination; [0, 0, 0] are block coords.
    data = load_ragged(x_desc, x_off, num_slices, [0, 0, 0], ragged_dim)
    store_ragged(y_desc, y_off, num_slices, [0, 0, 0], data, ragged_dim)
| 83 | + |
| 84 | + |
| 85 | +def _generate_test_params(): |
| 86 | + dtypes = ["float16", "float32"] |
| 87 | + modes = ["host", "device"] |
| 88 | + |
| 89 | + params = [] |
| 90 | + for dtype in dtypes: |
| 91 | + for mode in modes: |
| 92 | + # 2D tensors: only ragged_dim=0 is valid |
| 93 | + params.append((dtype, mode, 2, 0)) |
| 94 | + # 3D tensors: ragged_dim=0 and ragged_dim=1 are valid |
| 95 | + params.append((dtype, mode, 3, 0)) |
| 96 | + params.append((dtype, mode, 3, 1)) |
| 97 | + |
| 98 | + return params |
| 99 | + |
| 100 | + |
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
@pytest.mark.parametrize(
    "dtype_name,descriptor_mode,ndim,ragged_dim", _generate_test_params()
)
def test_ragged_tma(dtype_name, descriptor_mode, ndim, ragged_dim):
    """Round-trip a ragged TMA load/store and verify the destination.

    Copies `num_slices` slices along `ragged_dim` from an all-ones source
    into an all-zeros destination at offset `y_off`, then checks — within
    the block window of the non-ragged dims — that only the addressed
    slice range was written.

    Fix over the previous revision: the final check compared a list of
    booleans against [True, True, True]; it is now a direct boolean assert
    with the same diagnostic message.
    """
    torch_dtype = getattr(torch, dtype_name)

    # Tensor geometry per case; strides are contiguous row-major.
    if ndim == 2:
        shape = [128, 80]
        strides = [80, 1]
        block_shape = [32, 128]
    else:  # ndim == 3
        if ragged_dim == 0:
            shape = [64, 32, 32]
            strides = [32 * 32, 32, 1]
            block_shape = [16, 16, 32]
        else:  # ragged_dim == 1
            shape = [64, 32, 32]
            strides = [32 * 32, 32, 1]
            block_shape = [32, 16, 32]

    src = torch.ones(shape, dtype=torch_dtype, device="cuda")
    dst = torch.zeros(shape, dtype=torch_dtype, device="cuda")

    # Copy fewer slices than one block so the ragged boundary is exercised,
    # and land them off-center in dst so both sides of the window are checked.
    num_slices = min(block_shape[ragged_dim] - 1, shape[ragged_dim] // 3)
    x_off = 0
    y_off = (shape[ragged_dim] - num_slices) // 2

    def alloc_fn(size: int, align: int, stream: Optional[int]):
        # Scratch allocator used by triton for device-side descriptor storage.
        return torch.empty(size, dtype=torch.int8, device="cuda")

    triton.set_allocator(alloc_fn)

    if descriptor_mode == "host":
        x_desc = create_ragged_descriptor(src, block_shape, ragged_dim)
        y_desc = create_ragged_descriptor(dst, block_shape, ragged_dim)

        example_load_store_kernel_host_desc[(1,)](
            x_desc,
            y_desc,
            x_off,
            y_off,
            num_slices,
            ragged_dim,
            ndim,
        )
    else:
        if ndim == 2:
            example_load_store_kernel_device_desc_2d[(1,)](
                src,
                dst,
                x_off,
                y_off,
                num_slices,
                shape[0], shape[1],
                strides[0], strides[1],
                block_shape[0], block_shape[1],
                ragged_dim,
            )
        else:  # ndim == 3
            example_load_store_kernel_device_desc_3d[(1,)](
                src,
                dst,
                x_off,
                y_off,
                num_slices,
                shape[0], shape[1], shape[2],
                strides[0], strides[1], strides[2],
                block_shape[0], block_shape[1], block_shape[2],
                ragged_dim,
            )

    # Slice dst into the regions before / inside / after the copied window
    # along the ragged dim, restricted to the block window elsewhere.
    if ragged_dim == 0:
        if ndim == 2:
            before = dst[:y_off, : block_shape[1]]
            copied = dst[y_off : y_off + num_slices, : block_shape[1]]
            after = dst[y_off + num_slices :, : block_shape[1]]
        else:  # ndim == 3
            before = dst[:y_off, : block_shape[1], : block_shape[2]]
            copied = dst[y_off : y_off + num_slices, : block_shape[1], : block_shape[2]]
            after = dst[y_off + num_slices :, : block_shape[1], : block_shape[2]]
    else:  # ragged_dim == 1
        before = dst[: block_shape[0], :y_off, : block_shape[2]]
        copied = dst[: block_shape[0], y_off : y_off + num_slices, : block_shape[2]]
        after = dst[: block_shape[0], y_off + num_slices :, : block_shape[2]]

    before_ok = torch.all(before == 0.0).item()
    copied_ok = torch.all(copied == 1.0).item()
    after_ok = torch.all(after == 0.0).item()

    assert before_ok and copied_ok and after_ok, (
        f"Failed for {ndim}D {descriptor_mode} mode ragged_dim={ragged_dim}: "
        f"before={before_ok}, copied={copied_ok}, after={after_ok}"
    )
0 commit comments