Skip to content

Commit 743f178

Browse files
authored
Add descending flag to topk (#9355)
1 parent 03d956b commit 743f178

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

python/test/unit/language/test_standard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
4343
if k is None or x.numel < k:
4444
z = tl.sort(x, descending=descending)
4545
else:
46-
z = tl.topk(x, k)
46+
z = tl.topk(x, k, descending=descending)
4747
offs_z = offs_m[:, None] * stride_zm + offs_z_n[None, :]
4848
tl.store(Z + offs_z, z)
4949

@@ -54,7 +54,7 @@ def sort_kernel(X, stride_xm, Z, stride_zm, M: tl.constexpr, N: tl.constexpr, k:
5454
if k is None or x.numel() < k:
5555
y = torch.sort(x, descending=descending)[0]
5656
else:
57-
y = torch.topk(x, k=k).values
57+
y = torch.topk(x, k=k, largest=descending).values
5858
sort_kernel[(1, )](x, x.stride(0), z, z.stride(0), M, N, k, descending, num_warps=8)
5959
assert (y == z).all(), (y, z)
6060

python/triton/language/standard.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,9 +466,9 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE
466466

467467

468468
@jit
469-
def topk(x, k: core.constexpr, dim: core.constexpr = None):
469+
def topk(x, k: core.constexpr, dim: core.constexpr = None, descending: core.constexpr = True):
470470
"""
471-
Returns the k largest elements of the input tensor along the specified dimension.
471+
Returns the k largest (or smallest) elements of the input tensor along the specified dimension.
472472
473473
The elements are returned in sorted order (largest first).
474474
@@ -479,6 +479,8 @@ def topk(x, k: core.constexpr, dim: core.constexpr = None):
479479
:param dim: The dimension along which to find the top k elements.
480480
If None, uses the last dimension. Currently only the last dimension is supported.
481481
:type dim: int, optional
482+
:param descending: If set to True, returns k largest elements. If set to False, returns k smallest elements.
483+
:type descending: bool, optional
482484
:return: A tensor containing the k largest elements along the specified dimension.
483485
:rtype: Tensor
484486
@@ -488,7 +490,7 @@ def topk(x, k: core.constexpr, dim: core.constexpr = None):
488490
x = tl.arange(0, 16)
489491
top4 = tl.topk(x, 4) # Returns [15, 14, 13, 12]
490492
"""
491-
return sort_impl(x, k=k, dim=dim, descending=True)
493+
return sort_impl(x, k=k, dim=dim, descending=descending)
492494

493495

494496
@jit

0 commit comments

Comments
 (0)