Skip to content

Commit 8a6d399

Browse files
authored
add argsort operator (#956)
* add argsort operator
1 parent aad5d1d commit 8a6d399

6 files changed

Lines changed: 523 additions & 1 deletion

File tree

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
.. _argsort_func:
2+
3+
argsort
4+
#######
5+
6+
Compute the indices that would sort the elements of a tensor in either ascending or descending order.
7+
8+
.. doxygenfunction:: argsort(const InputOperator &a, const SortDirection_t dir)
9+
10+
Examples
11+
~~~~~~~~
12+
13+
.. literalinclude:: ../../../test/00_tensor/CUBTests.cu
14+
:language: cpp
15+
:start-after: example-begin argsort-test-1
16+
:end-before: example-end argsort-test-1
17+
:dedent:
18+
19+
20+
.. literalinclude:: ../../../test/00_tensor/CUBTests.cu
21+
:language: cpp
22+
:start-after: example-begin argsort-test-2
23+
:end-before: example-end argsort-test-2
24+
:dedent:
25+
26+
27+

include/matx/core/iterator.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,23 @@ struct RandomOperatorOutputIterator {
291291
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator==(const self_type &a, const self_type &b)
292292
{
293293
return a.offset_ == b.offset_;
294-
}
294+
}
295+
296+
// Strict "less than" ordering between two iterators, decided solely by
// comparing their offset_ members. Callable from host and device.
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator<(const self_type &a, const self_type &b) {
297+
return a.offset_ < b.offset_;
298+
}
299+
300+
// Strict "greater than" ordering between two iterators, decided solely by
// comparing their offset_ members. Callable from host and device.
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator>(const self_type &a, const self_type &b) {
301+
return a.offset_ > b.offset_;
302+
}
303+
304+
// "Less than or equal" ordering between two iterators, decided solely by
// comparing their offset_ members. Callable from host and device.
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator<=(const self_type &a, const self_type &b) {
305+
return a.offset_ <= b.offset_;
306+
}
307+
308+
// "Greater than or equal" ordering between two iterators, decided solely by
// comparing their offset_ members. Callable from host and device.
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ friend bool operator>=(const self_type &a, const self_type &b) {
309+
return a.offset_ >= b.offset_;
310+
}
295311

296312
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() {
297313
return OperatorType::Rank();

include/matx/operators/argsort.h

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2025, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#pragma once
34+
35+
36+
#include "matx/core/type_utils.h"
37+
#include "matx/operators/base_operator.h"
38+
#include "matx/transforms/cub.h"
39+
40+
namespace matx {
41+
42+
43+
44+
namespace detail {
45+
// Operator that exposes the result of an argsort over the input operator `a_`.
//
// The sort itself is not performed element-by-element in operator(): PreRun()
// allocates a temporary index_t tensor, calls argsort_impl() to fill it, and
// operator() then simply reads from that temporary. PostRun() releases the
// temporary allocation.
//
// OpA is the type of the input operator; the output has the same rank and the
// same per-dimension sizes as the input.
template<typename OpA>
46+
class ArgsortOp : public BaseOp<ArgsortOp<OpA>>
47+
{
48+
private:
49+
// Stored copy of the input operator (as its base_type_t).
typename detail::base_type_t<OpA> a_;
50+
// Sort direction handed through unchanged to argsort_impl().
SortDirection_t dir_;
51+
// Per-dimension sizes of the output, copied from a_ in the constructor.
cuda::std::array<index_t, OpA::Rank()> out_dims_;
52+
// Temporary tensor that receives the argsort result in PreRun(); mutable
// because PreRun() is a const member.
mutable detail::tensor_impl_t<index_t, OpA::Rank()> tmp_out_;
53+
// Pointer filled in by AllocateTempTensor() and released with matxFree() in
// PostRun() -- presumably the backing storage of tmp_out_; confirm against
// AllocateTempTensor().
mutable index_t *ptr = nullptr;
54+
55+
public:
56+
// Tag/trait aliases and the element type produced by this operator.
using matxop = bool;
57+
using value_type = index_t;
58+
using matx_transform_op = bool;
59+
using sort_xform_op = bool;
60+
61+
// Short printable name of this operator.
__MATX_INLINE__ std::string str() const { return "argsort()"; }
62+
// Stores the input operator and sort direction, and records the size of
// every input dimension as the output shape.
__MATX_INLINE__ ArgsortOp(const OpA &a, const SortDirection_t dir) : a_(a), dir_(dir) {
63+
for (int r = 0; r < Rank(); r++) {
64+
out_dims_[r] = a_.Size(r);
65+
}
66+
}
67+
68+
// Returns the raw temporary pointer (nullptr until PreRun() has allocated it).
__MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; }
69+
70+
// Element access: forwards the indices to the temporary result tensor.
template <typename... Is>
71+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const {
72+
return tmp_out_(indices...);
73+
}
74+
75+
// Runs the argsort: `out` is a tuple whose first element is the destination,
// which is passed to argsort_impl() together with the input, the direction
// and the executor.
template <typename Out, typename Executor>
76+
void Exec(Out &&out, Executor &&ex) const {
77+
argsort_impl(cuda::std::get<0>(out), a_, dir_, ex);
78+
}
79+
80+
// Rank of the output, identical to the rank of the input operator type.
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
81+
{
82+
return OpA::Rank();
83+
}
84+
85+
// Propagates PreRun() to the input operator when OpA is itself a MatX operator.
template <typename ShapeType, typename Executor>
86+
__MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
87+
{
88+
if constexpr (is_matx_op<OpA>()) {
89+
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
90+
}
91+
}
92+
93+
// Prepares this operator for evaluation, in order:
//   1. run the input's PreRun (via InnerPreRun),
//   2. allocate the temporary output tensor with shape out_dims_,
//   3. execute the argsort into that temporary.
// NOTE(review): `ex` is passed through std::forward three times here; if
// Executor were ever deduced as a non-reference type the later uses would see
// a moved-from object -- confirm callers always pass an lvalue executor.
template <typename ShapeType, typename Executor>
94+
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
95+
{
96+
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
97+
98+
detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_, &ptr);
99+
100+
Exec(cuda::std::make_tuple(tmp_out_), std::forward<Executor>(ex));
101+
}
102+
103+
// Cleans up after evaluation: propagates PostRun() to the input operator when
// it is a MatX operator, then frees the temporary allocation.
// NOTE(review): ptr is not reset to nullptr after matxFree(), so Data() keeps
// returning the freed address until the next PreRun() -- confirm this is
// never observed by callers.
template <typename ShapeType, typename Executor>
104+
__MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept
105+
{
106+
if constexpr (is_matx_op<OpA>()) {
107+
a_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
108+
}
109+
110+
matxFree(ptr);
111+
}
112+
113+
// Size of output dimension `dim`, as recorded from the input in the constructor.
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
114+
{
115+
return out_dims_[dim];
116+
}
117+
118+
};
119+
}
120+
121+
/**
122+
* Argsort rows of an operator
123+
*
124+
* Generates indices that would sort the rows of an operator.
125+
* Currently supported types are float, double, ints, and long ints (both signed
126+
* and unsigned). For a 1D operator, a linear sort is performed. For 2D and above
127+
* each row of the inner dimensions is batched and sorted separately.
128+
*
129+
* @note Temporary memory may be used during the sorting process, and about 4N will
130+
* be allocated, where N is the length of the tensor.
131+
*
132+
* @tparam InputOperator
133+
* Input type
134+
* @param a
135+
* Input operator
136+
* @param dir
137+
* Direction to sort (either SORT_DIR_ASC or SORT_DIR_DESC)
138+
* @returns Operator containing indices that would sort the tensor
139+
*/
140+
// Factory for the lazy argsort operator; `dir` defaults to SORT_DIR_ASC.
template <typename InputOperator>
141+
__MATX_INLINE__ auto argsort(const InputOperator &a, const SortDirection_t dir = SORT_DIR_ASC) {
142+
// No template argument is spelled out: class template argument deduction
// selects ArgsortOp's parameter from the constructor arguments.
return detail::ArgsortOp(a, dir);
143+
}
144+
145+
}

include/matx/operators/operators.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,4 @@
129129
#include "matx/operators/argminmax.h"
130130
#include "matx/operators/all.h"
131131
#include "matx/operators/any.h"
132+
#include "matx/operators/argsort.h"

0 commit comments

Comments
 (0)