|
14 | 14 |
|
15 | 15 | import triton |
16 | 16 | import triton.language as tl |
| 17 | +from triton.tools.tensor_descriptor import TensorDescriptor |
17 | 18 |
|
18 | 19 | from triton._internal_testing import ( |
19 | 20 | integral_dtypes, |
@@ -1918,27 +1919,59 @@ def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexp |
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_str, num_warps",
                         [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]])
@pytest.mark.parametrize("can_reorder", [True, False])
def test_cat(dtype_str, num_warps, can_reorder, device):
    """1-D tl.cat: with reordering allowed only the multiset of elements is
    checked; without it the exact element order must match torch.cat."""
    check_type_supported(dtype_str, device)

    @triton.jit
    def kernel(X, Y, Z, N: tl.constexpr, CAN_REORDER: tl.constexpr):
        offs = tl.arange(0, N)
        joined = tl.cat(tl.load(X + offs), tl.load(Y + offs), can_reorder=CAN_REORDER)
        tl.store(Z + tl.arange(0, 2 * N), joined)

    torch_dtype = getattr(torch, dtype_str)
    x = torch.arange(0, 128, device=device).to(torch_dtype)
    y = torch.arange(-128, 0, device=device).to(torch_dtype)
    expected = torch.cat([x, y], dim=0)
    out = torch.zeros((256, ), dtype=torch_dtype, device=device)
    kernel[(1, )](x, y, out, N=128, num_warps=num_warps, CAN_REORDER=can_reorder)
    # Order-insensitive check first: a reordered cat still preserves the sum.
    assert out.sum() == expected.sum()
    if not can_reorder:
        # Without reordering the output must be an exact, in-order concatenation.
        torch.testing.assert_close(out, expected, atol=0, rtol=0)
    # check if there's no duplicate value in z
    assert out.unique().size(0) == out.size(0)
1940 | 1944 |
|
1941 | 1945 |
|
# All (shape, dim) combinations used to exercise tl.cat across ranks 1-4.
CAT_ND_SHAPES = ((128, ), (16, 32), (8, 16, 4), (2, 4, 8, 16))
CAT_ND_CASES = [
    pytest.param(shape, dim, id=f"rank={len(shape)},dim={dim}")
    for shape in CAT_ND_SHAPES
    for dim in range(len(shape))
]
| 1951 | + |
| 1952 | + |
@pytest.mark.parametrize("shape, dim", CAT_ND_CASES)
def test_cat_nd(shape, dim, device):
    """N-D tl.cat via tensor descriptors: concatenate two equal-shape blocks
    along `dim` and compare elementwise against torch.cat."""

    @triton.jit
    def kernel(x_desc, y_desc, z_desc, dim: tl.constexpr, shape: tl.constexpr):
        rank: tl.constexpr = len(shape)
        # Load both full blocks at the origin, concatenate, store the result.
        lhs = x_desc.load([0] * rank)
        rhs = y_desc.load([0] * rank)
        joined = tl.cat(lhs, rhs, dim=dim)
        z_desc.store([0] * rank, joined)

    lhs = torch.rand(shape, device=device)
    rhs = torch.rand(shape, device=device)
    expected = torch.cat([lhs, rhs], dim=dim)
    result = torch.empty_like(expected)
    # Each descriptor covers its whole tensor as a single block.
    lhs_desc = TensorDescriptor.from_tensor(lhs, block_shape=shape)
    rhs_desc = TensorDescriptor.from_tensor(rhs, block_shape=shape)
    out_desc = TensorDescriptor.from_tensor(result, block_shape=expected.shape)
    kernel[(1, )](lhs_desc, rhs_desc, out_desc, dim=dim, shape=shape)
    torch.testing.assert_close(result, expected, atol=0, rtol=0)
| 1973 | + |
| 1974 | + |
1942 | 1975 | @pytest.mark.interpreter |
1943 | 1976 | @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) |
1944 | 1977 | @pytest.mark.parametrize("constant_field", ["value", "mask"]) |
|
0 commit comments