Skip to content

Commit 8f2c29a

Browse files
committed
couple of minor fixes to arg min/max
1 parent 76e6fbf commit 8f2c29a

File tree

5 files changed

+191
-207
lines changed

5 files changed

+191
-207
lines changed

include/rxmesh/arg_ops.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include <limits>
4+
5+
namespace rxmesh {
6+
7+
template <typename HandleT, typename T>
8+
using KeyValuePair = cub::KeyValuePair<HandleT, T>;
9+
10+
namespace detail {
11+
12+
/**
 * @brief Binary reduction functor that keeps the key/value pair holding the
 * larger value (arg-max).
 *
 * Both members are callable from host and device code so the same functor can
 * seed per-thread values inside a kernel and be applied on the host.
 *
 * @tparam HandleT type stored in the `key` field of the pair
 * @tparam T type stored in the `value` field of the pair; compared with `>`
 */
template <typename HandleT, typename T>
struct ArgMaxOp
{
    /**
     * @brief Value used to seed a reduction before any element is visited.
     * @return std::numeric_limits<T>::lowest()
     *
     * NOTE(review): std::numeric_limits is a host constexpr facility; calling
     * it on the device relies on nvcc's --expt-relaxed-constexpr -- confirm
     * the project's build flags.
     */
    __host__ __device__ constexpr T default_val() const
    {
        return std::numeric_limits<T>::lowest();
    }

    /**
     * @brief Pick the pair with the larger value.
     * @param a first candidate; returned when the values compare equal or
     * when `b.value > a.value` is false for any other reason
     * @param b second candidate; returned only when `b.value > a.value`
     * @return a copy of the selected pair
     */
    __host__ __device__ __forceinline__ KeyValuePair<HandleT, T> operator()(
        const KeyValuePair<HandleT, T>& a,
        const KeyValuePair<HandleT, T>& b) const
    {
        return (b.value > a.value) ? b : a;
    }
};
27+
28+
29+
/**
 * @brief Binary reduction functor that keeps the key/value pair holding the
 * smaller value (arg-min).
 *
 * Both members are callable from host and device code so the same functor can
 * seed per-thread values inside a kernel and be applied on the host.
 *
 * @tparam HandleT type stored in the `key` field of the pair
 * @tparam T type stored in the `value` field of the pair; compared with `<`
 */
template <typename HandleT, typename T>
struct ArgMinOp
{
    /**
     * @brief Value used to seed a reduction before any element is visited.
     * @return std::numeric_limits<T>::max()
     *
     * NOTE(review): std::numeric_limits is a host constexpr facility; calling
     * it on the device relies on nvcc's --expt-relaxed-constexpr -- confirm
     * the project's build flags.
     */
    __host__ __device__ constexpr T default_val() const
    {
        return std::numeric_limits<T>::max();
    }

    /**
     * @brief Pick the pair with the smaller value.
     * @param a first candidate; returned when the values compare equal or
     * when `b.value < a.value` is false for any other reason
     * @param b second candidate; returned only when `b.value < a.value`
     * @return a copy of the selected pair
     */
    __host__ __device__ __forceinline__ KeyValuePair<HandleT, T> operator()(
        const KeyValuePair<HandleT, T>& a,
        const KeyValuePair<HandleT, T>& b) const
    {
        return (b.value < a.value) ? b : a;
    }
};
44+
45+
46+
} // namespace detail
47+
} // namespace rxmesh

include/rxmesh/cavity_manager_impl.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3826,7 +3826,7 @@ CavityManager<blockThreads, cop>::populate_correspondence(
38263826
const LPPair lp =
38273827
m_patch_info.get_lp<HandleT>().find(b, s_table, s_stash);
38283828

3829-
assert(lp.local_id == b);
3829+
assert(lp.local_id() == b);
38303830

38313831
// inner
38323832
for (int c = 0; c < q_num_elements; ++c) {

include/rxmesh/kernels/attribute.cuh

Lines changed: 43 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <cub/block/block_reduce.cuh>
33
#include "rxmesh/util/macros.h"
44

5+
#include "rxmesh/arg_ops.h"
56

67
namespace rxmesh {
78

@@ -10,7 +11,7 @@ class Attribute;
1011

1112
namespace detail {
1213

13-
template <class T, uint32_t blockSize>
14+
template <uint32_t blockSize, class T>
1415
__device__ __forceinline__ void cub_block_sum(const T thread_val,
1516
T* d_block_output)
1617
{
@@ -22,6 +23,23 @@ __device__ __forceinline__ void cub_block_sum(const T thread_val,
2223
}
2324
}
2425

26+
/**
 * @brief Reduce one value per thread across the thread block with a
 * caller-supplied binary operator and store the block's result in
 * d_block_output[blockIdx.x].
 *
 * The reduction is delegated to cub::BlockReduce<T, blockSize>, so every
 * thread of the block has to call this function.
 *
 * @tparam blockSize number of threads handed to cub::BlockReduce
 * @tparam T type of the value being reduced
 * @tparam ReductionOp binary functor passed to BlockReduce::Reduce
 * @param thread_val this thread's contribution to the reduction
 * @param d_block_output output array indexed by blockIdx.x
 * @param reduction_op operator combining two values of type T
 */
template <uint32_t blockSize, class T, typename ReductionOp>
__device__ __forceinline__ void cub_block_reduce(const T     thread_val,
                                                 T*          d_block_output,
                                                 ReductionOp reduction_op)
{
    using BlockReduceT = cub::BlockReduce<T, blockSize>;

    // scratch space required by cub::BlockReduce
    __shared__ typename BlockReduceT::TempStorage s_reduce_storage;

    const T reduced =
        BlockReduceT(s_reduce_storage).Reduce(thread_val, reduction_op);

    // only thread 0 publishes the block's result
    if (threadIdx.x != 0) {
        return;
    }
    d_block_output[blockIdx.x] = reduced;
}
42+
2543
template <class T, uint32_t blockSize, typename HandleT>
2644
__launch_bounds__(blockSize) __global__
2745
void norm2_kernel(const Attribute<T, HandleT> X,
@@ -52,7 +70,7 @@ __launch_bounds__(blockSize) __global__
5270
}
5371
}
5472

55-
cub_block_sum<T, blockSize>(thread_val, d_block_output);
73+
cub_block_sum<blockSize>(thread_val, d_block_output);
5674
}
5775
}
5876

@@ -90,95 +108,53 @@ __launch_bounds__(blockSize) __global__
90108
}
91109
}
92110

93-
cub_block_sum<T, blockSize>(thread_val, d_block_output);
111+
cub_block_sum<blockSize>(thread_val, d_block_output);
94112
}
95113
}
96114

97-
template <typename HandleT, typename T>
98-
struct CustomMaxPair
99-
{
100-
__host__ __device__ CustomMaxPair()
101-
{
102-
default_val = (std::numeric_limits<T>::lowest());
103-
}
104-
105-
__device__ __forceinline__ cub::KeyValuePair<HandleT, T> operator()(
106-
const cub::KeyValuePair<HandleT, T>& a,
107-
const cub::KeyValuePair<HandleT, T>& b) const
108-
{
109-
return (b.value > a.value) ? b : a;
110-
}
111-
T default_val;
112-
};
113-
114-
template <typename HandleT, typename T>
115-
struct CustomMinPair
116-
{
117-
__host__ __device__ CustomMinPair()
118-
{
119-
default_val = (std::numeric_limits<T>::max());
120-
}
121-
__device__ __forceinline__ cub::KeyValuePair<HandleT, T> operator()(
122-
const cub::KeyValuePair<HandleT, T>& a,
123-
const cub::KeyValuePair<HandleT, T>& b) const
124-
{
125-
return (b.value < a.value) ? b : a;
126-
}
127-
T default_val;
128-
};
129-
130115
template <class T, uint32_t blockSize, typename HandleT, typename Operation>
131116
__launch_bounds__(blockSize) __global__
132117
void arg_minmax_kernel(const Attribute<T, HandleT> X,
133-
uint32_t attribute_id,
134-
Operation op, //can be either max or min operation
135-
const uint32_t num_patches,
136-
const uint32_t num_attributes,
137-
cub::KeyValuePair<HandleT, T>* d_block_output)
118+
uint32_t attribute_id,
119+
Operation reduction_op,
120+
const uint32_t num_patches,
121+
const uint32_t num_attributes,
122+
KeyValuePair<HandleT, T>* d_block_output)
138123
{
139124
using LocalT = typename HandleT::LocalT;
140125

141-
assert(X.get_num_attributes() == 1); //we can only take arg max for a scalar attribute
142-
143126
uint32_t p_id = blockIdx.x;
144127
if (p_id < num_patches) {
145-
const uint16_t element_per_patch = X.size(p_id);
146-
cub::KeyValuePair<HandleT, T> thread_val;
147-
thread_val.value = op.default_val;
128+
const uint16_t element_per_patch = X.size(p_id);
129+
KeyValuePair<HandleT, T> thread_val;
130+
thread_val.value = reduction_op.default_val();
148131
thread_val.key = HandleT(p_id, threadIdx.x);
149132
for (uint16_t i = threadIdx.x; i < element_per_patch; i += blockSize) {
150133

151134
if (X.get_patch_info(p_id).is_owned(LocalT(i)) &&
152135
!X.get_patch_info(p_id).is_deleted(LocalT(i))) {
153136

154-
if (attribute_id != INVALID32 )
155-
{
156-
HandleT handle(p_id, i);
157-
cub::KeyValuePair<HandleT, T> current_pair(handle, X(p_id, i, attribute_id));
158-
thread_val = op(thread_val, current_pair);
159-
}
160-
else {
161-
for (uint32_t j = 0; j < num_attributes; ++j)
162-
{
163-
HandleT handle(p_id, i);
164-
cub::KeyValuePair<HandleT, T> current_pair(handle, X(p_id, i, j));
165-
thread_val = op(thread_val, current_pair);
137+
if (attribute_id != INVALID32) {
138+
HandleT handle(p_id, i);
139+
KeyValuePair<HandleT, T> current_pair(
140+
handle, X(p_id, i, attribute_id));
141+
thread_val = reduction_op(thread_val, current_pair);
142+
} else {
143+
for (uint32_t j = 0; j < num_attributes; ++j) {
144+
HandleT handle(p_id, i);
145+
KeyValuePair<HandleT, T> current_pair(handle,
146+
X(p_id, i, j));
147+
thread_val = reduction_op(thread_val, current_pair);
166148
}
167149
}
168150
}
169151
}
170-
typedef cub::BlockReduce<cub::KeyValuePair<HandleT, T>, blockSize> BlockReduce;
171-
__shared__ typename BlockReduce::TempStorage temp_storage;
172-
cub::KeyValuePair<HandleT, T> block_aggregate = BlockReduce(temp_storage).Reduce(thread_val, op);
173-
if (threadIdx.x == 0)
174-
{
175-
d_block_output[blockIdx.x] = block_aggregate;
176-
}
152+
153+
cub_block_reduce<blockSize>(thread_val, d_block_output, reduction_op);
177154
}
178155
}
179156

180157

181-
182158
template <class T, uint32_t blockSize, typename ReductionOp, typename HandleT>
183159
__launch_bounds__(blockSize) __global__
184160
void generic_reduce(const Attribute<T, HandleT> X,
@@ -209,14 +185,8 @@ __launch_bounds__(blockSize) __global__
209185
}
210186
}
211187
}
212-
typedef cub::BlockReduce<T, blockSize> BlockReduce;
213-
__shared__ typename BlockReduce::TempStorage temp_storage;
214188

215-
T block_aggregate =
216-
BlockReduce(temp_storage).Reduce(thread_val, reduction_op);
217-
if (threadIdx.x == 0) {
218-
d_block_output[blockIdx.x] = block_aggregate;
219-
}
189+
cub_block_reduce<blockSize>(thread_val, d_block_output, reduction_op);
220190
}
221191
}
222192

0 commit comments

Comments
 (0)