@@ -466,9 +466,9 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE
466466
467467
468468@jit
469- def topk (x , k : core .constexpr , dim : core .constexpr = None ):
469+ def topk (x , k : core .constexpr , dim : core .constexpr = None , descending : core . constexpr = True ):
470470 """
471- Returns the k largest elements of the input tensor along the specified dimension.
471+ Returns the k largest (or smallest) elements of the input tensor along the specified dimension.
472472
473473 The elements are returned in sorted order (largest first).
474474
@@ -479,6 +479,8 @@ def topk(x, k: core.constexpr, dim: core.constexpr = None):
479479 :param dim: The dimension along which to find the top k elements.
480480 If None, uses the last dimension. Currently only the last dimension is supported.
481481 :type dim: int, optional
482+ :param descending: If set to True, returns k largest elements. If set to False, returns k smallest elements.
483+ :type descending: bool, optional
482484 :return: A tensor containing the k largest elements along the specified dimension.
483485 :rtype: Tensor
484486
@@ -488,7 +490,7 @@ def topk(x, k: core.constexpr, dim: core.constexpr = None):
488490 x = tl.arange(0, 16)
489491 top4 = tl.topk(x, 4) # Returns [15, 14, 13, 12]
490492 """
491- return sort_impl (x , k = k , dim = dim , descending = True )
493+ return sort_impl (x , k = k , dim = dim , descending = descending )
492494
493495
494496@jit
0 commit comments