
Commit 4278d7f

[triton] Add tl.cat(can_reorder=False) implementation (#9312)
This resurrects the old PR that replaced the implementation entirely. I also fixed `tl.cat` to be semantically equivalent to `torch.cat`.
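
For context, a minimal usage sketch of the new path, adapted from the test added in this commit (the kernel name, device, dtype, and sizes are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def cat_kernel(X, Y, Z, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(X + offs)
    y = tl.load(Y + offs)
    # can_reorder=False preserves element order, matching torch.cat.
    z = tl.cat(x, y, can_reorder=False)
    tl.store(Z + tl.arange(0, 2 * N), z)

x = torch.arange(0, 128, device="cuda", dtype=torch.float32)
y = torch.arange(-128, 0, device="cuda", dtype=torch.float32)
z = torch.empty(256, device="cuda", dtype=torch.float32)
cat_kernel[(1, )](x, y, z, N=128)
torch.testing.assert_close(z, torch.cat([x, y]), atol=0, rtol=0)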
1 parent e9489f1 commit 4278d7f

3 files changed

Lines changed: 64 additions & 12 deletions

File tree

python/test/unit/language/test_core.py
python/triton/language/core.py
python/triton/language/semantic.py

python/test/unit/language/test_core.py

Lines changed: 39 additions & 6 deletions
@@ -14,6 +14,7 @@
 
 import triton
 import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 from triton._internal_testing import (
     integral_dtypes,
@@ -1918,27 +1919,59 @@ def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexp
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str, num_warps",
                          [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]])
-def test_cat(dtype_str, num_warps, device):
+@pytest.mark.parametrize("can_reorder", [True, False])
+def test_cat(dtype_str, num_warps, can_reorder, device):
     check_type_supported(dtype_str, device)
 
     @triton.jit
-    def kernel(X, Y, Z, N: tl.constexpr):
+    def kernel(X, Y, Z, N: tl.constexpr, CAN_REORDER: tl.constexpr):
         offs = tl.arange(0, N)
         x = tl.load(X + offs)
         y = tl.load(Y + offs)
-        z = tl.cat(x, y, can_reorder=True)
+        z = tl.cat(x, y, can_reorder=CAN_REORDER)
         tl.store(Z + tl.arange(0, 2 * N), z)
 
     x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str))
     y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str))
-    z_ref = torch.cat([x, y], dim=0).sum()
+    z_ref = torch.cat([x, y], dim=0)
     z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device)
-    kernel[(1, )](x, y, z, N=128, num_warps=num_warps)
-    assert z.sum() == z_ref
+    kernel[(1, )](x, y, z, N=128, num_warps=num_warps, CAN_REORDER=can_reorder)
+    assert z.sum() == z_ref.sum()
+    if not can_reorder:
+        torch.testing.assert_close(z, z_ref, atol=0, rtol=0)
     # check if there's no duplicate value in z
     assert z.unique().size(0) == z.size(0)
 
 
+CAT_ND_SHAPES = ((128, ), (16, 32), (8, 16, 4), (2, 4, 8, 16))
+CAT_ND_CASES = []
+for shape in CAT_ND_SHAPES:
+    for dim in range(len(shape)):
+        CAT_ND_CASES.append(pytest.param(shape, dim, id=f"rank={len(shape)},dim={dim}"))
+
+
+@pytest.mark.parametrize("shape, dim", CAT_ND_CASES)
+def test_cat_nd(shape, dim, device):
+
+    @triton.jit
+    def kernel(x_desc, y_desc, z_desc, dim: tl.constexpr, shape: tl.constexpr):
+        rank: tl.constexpr = len(shape)
+        x = x_desc.load([0] * rank)
+        y = y_desc.load([0] * rank)
+        z = tl.cat(x, y, dim=dim)
+        z_desc.store([0] * rank, z)
+
+    x = torch.rand(shape, device=device)
+    y = torch.rand(shape, device=device)
+    z_ref = torch.cat([x, y], dim=dim)
+    z = torch.empty_like(z_ref)
+    x_desc = TensorDescriptor.from_tensor(x, block_shape=shape)
+    y_desc = TensorDescriptor.from_tensor(y, block_shape=shape)
+    z_desc = TensorDescriptor.from_tensor(z, block_shape=z_ref.shape)
+    kernel[(1, )](x_desc, y_desc, z_desc, dim=dim, shape=shape)
+    torch.testing.assert_close(z, z_ref, atol=0, rtol=0)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str", list(torch_dtypes))
 @pytest.mark.parametrize("constant_field", ["value", "mask"])
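
The two branches of test_cat check different guarantees. A plain-PyTorch sketch of the expectation, with small illustrative sizes:

import torch

x = torch.arange(0, 4)
y = torch.arange(-4, 0)
z_ref = torch.cat([x, y], dim=0)  # tensor([ 0,  1,  2,  3, -4, -3, -2, -1])

# can_reorder=False: the kernel output must equal z_ref element for element.
# can_reorder=True: only the multiset of elements is guaranteed, so the test
# compares sums and checks that no element appears twice.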

python/triton/language/core.py

Lines changed: 23 additions & 5 deletions
@@ -1795,20 +1795,38 @@ def permute(input, *dims, _semantic=None):
 
 
 @builtin
-def cat(input, other, can_reorder=False, _semantic=None):
+def cat(input, other, can_reorder=False, dim=0, _semantic=None):
     """
     Concatenate the given blocks
 
     :param input: The first input tensor.
     :type input: Tensor
     :param other: The second input tensor.
     :type other: Tensor
-    :param reorder: Compiler hint. If true, the compiler is
+    :param can_reorder: Compiler hint. If true, the compiler is
         allowed to reorder elements while concatenating inputs. Only use if the
         order does not matter (e.g., result is only used in reduction ops).
-        Current implementation of `cat` supports only can_reorder=True.
-    """
-    return _semantic.cat(input, other, can_reorder)
+    :type can_reorder: bool
+    :param dim: The dimension to concatenate along (used when can_reorder is False).
+    :type dim: int
+    """
+    if can_reorder:
+        return _semantic.cat(input, other, can_reorder)
+
+    rank = len(input.shape)
+    assert rank == len(other.shape), f"tensors must have the same rank, got {rank} and {len(other.shape)}"
+    dim = _wrap_axis(_unwrap_if_constexpr(dim), rank)
+    assert all(input.shape[i] == other.shape[i] for i in builtins.range(rank) if i !=
+               dim), f"tensor dims must match except in the concat dimension {dim}, got {input.shape} and {other.shape}"
+
+    # Join introduces a new minor dim; move it before the concat dim and merge.
+    c = join(input, other, _semantic=_semantic)
+    order = list(builtins.range(rank))
+    order.insert(dim, rank)
+    c = permute(c, order, _semantic=_semantic)
+    new_shape = list(input.shape)
+    new_shape[dim] = input.shape[dim] + other.shape[dim]
+    return reshape(c, new_shape, _semantic=_semantic)
 
 
 @builtin
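
The comment in the new non-reorder path summarizes the construction: join stacks the two inputs along a new minor axis, permute moves that axis just before the concat dimension, and reshape merges the two into one. A standalone sketch of the same idea in plain PyTorch, with torch.stack standing in for tl.join (illustrative, not part of the commit):

import torch

def cat_via_join(x: torch.Tensor, y: torch.Tensor, dim: int) -> torch.Tensor:
    rank = x.ndim
    c = torch.stack([x, y], dim=rank)  # "join": new minor axis of size 2
    order = list(range(rank))
    order.insert(dim, rank)            # move the new axis just before `dim`
    c = c.permute(order)
    new_shape = list(x.shape)
    new_shape[dim] = 2 * x.shape[dim]  # fold the size-2 axis into `dim`
    return c.reshape(new_shape)

x, y = torch.rand(4, 8), torch.rand(4, 8)
for d in range(2):
    assert torch.equal(cat_via_join(x, y, d), torch.cat([x, y], dim=d))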

python/triton/language/semantic.py

Lines changed: 2 additions & 1 deletion
@@ -694,7 +694,8 @@ def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
 
     def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
         if len(input.shape) != len(dims):
-            raise ValueError("permute dims must have the same length as input shape")
+            raise ValueError(
+                f"permute dims must have the same length as input shape, got {len(input.shape)} and {len(dims)}")
         if sorted(tl._unwrap_if_constexpr(d) for d in dims) != list(range(len(dims))):
             raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")
 
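A hypothetical repro of the improved diagnostic (the kernel fragment is illustrative, not from the commit):

# Inside a @triton.jit kernel, permuting a rank-2 block with three dims
# now reports both lengths:
#
#     x = tl.zeros((4, 8), dtype=tl.float32)
#     y = tl.permute(x, (1, 0, 2))
#
# ValueError: permute dims must have the same length as input shape, got 2 and 3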