22#include < cub/block/block_reduce.cuh>
33#include " rxmesh/util/macros.h"
44
5+ #include " rxmesh/arg_ops.h"
56
67namespace rxmesh {
78
@@ -10,7 +11,7 @@ class Attribute;
1011
1112namespace detail {
1213
13- template <class T , uint32_t blockSize >
14+ template <uint32_t blockSize, class T >
1415__device__ __forceinline__ void cub_block_sum (const T thread_val,
1516 T* d_block_output)
1617{
@@ -22,6 +23,23 @@ __device__ __forceinline__ void cub_block_sum(const T thread_val,
2223 }
2324}
2425
26+ template <uint32_t blockSize, class T , typename ReductionOp>
27+ __device__ __forceinline__ void cub_block_reduce (const T thread_val,
28+ T* d_block_output,
29+ ReductionOp reduction_op)
30+ {
31+ typedef cub::BlockReduce<T, blockSize> BlockReduce;
32+
33+ __shared__ typename BlockReduce::TempStorage temp_storage;
34+
35+ T block_aggregate =
36+ BlockReduce (temp_storage).Reduce (thread_val, reduction_op);
37+
38+ if (threadIdx .x == 0 ) {
39+ d_block_output[blockIdx .x ] = block_aggregate;
40+ }
41+ }
42+
2543template <class T , uint32_t blockSize, typename HandleT>
2644__launch_bounds__ (blockSize) __global__
2745 void norm2_kernel (const Attribute<T, HandleT> X,
@@ -52,7 +70,7 @@ __launch_bounds__(blockSize) __global__
5270 }
5371 }
5472
55- cub_block_sum<T, blockSize>(thread_val, d_block_output);
73+ cub_block_sum<blockSize>(thread_val, d_block_output);
5674 }
5775}
5876
@@ -90,95 +108,53 @@ __launch_bounds__(blockSize) __global__
90108 }
91109 }
92110
93- cub_block_sum<T, blockSize>(thread_val, d_block_output);
111+ cub_block_sum<blockSize>(thread_val, d_block_output);
94112 }
95113}
96114
97- template <typename HandleT, typename T>
98- struct CustomMaxPair
99- {
100- __host__ __device__ CustomMaxPair ()
101- {
102- default_val = (std::numeric_limits<T>::lowest ());
103- }
104-
105- __device__ __forceinline__ cub::KeyValuePair<HandleT, T> operator ()(
106- const cub::KeyValuePair<HandleT, T>& a,
107- const cub::KeyValuePair<HandleT, T>& b) const
108- {
109- return (b.value > a.value ) ? b : a;
110- }
111- T default_val;
112- };
113-
114- template <typename HandleT, typename T>
115- struct CustomMinPair
116- {
117- __host__ __device__ CustomMinPair ()
118- {
119- default_val = (std::numeric_limits<T>::max ());
120- }
121- __device__ __forceinline__ cub::KeyValuePair<HandleT, T> operator ()(
122- const cub::KeyValuePair<HandleT, T>& a,
123- const cub::KeyValuePair<HandleT, T>& b) const
124- {
125- return (b.value < a.value ) ? b : a;
126- }
127- T default_val;
128- };
129-
130115template <class T , uint32_t blockSize, typename HandleT, typename Operation>
131116__launch_bounds__ (blockSize) __global__
132117 void arg_minmax_kernel (const Attribute<T, HandleT> X,
133- uint32_t attribute_id,
134- Operation op, // can be either max or min operation
135- const uint32_t num_patches,
136- const uint32_t num_attributes,
137- cub:: KeyValuePair<HandleT, T>* d_block_output)
118+ uint32_t attribute_id,
119+ Operation reduction_op,
120+ const uint32_t num_patches,
121+ const uint32_t num_attributes,
122+ KeyValuePair<HandleT, T>* d_block_output)
138123{
139124 using LocalT = typename HandleT::LocalT;
140125
141- assert (X.get_num_attributes () == 1 ); // we can only take arg max for a scalar attribute
142-
143126 uint32_t p_id = blockIdx .x ;
144127 if (p_id < num_patches) {
145- const uint16_t element_per_patch = X.size (p_id);
146- cub:: KeyValuePair<HandleT, T> thread_val;
147- thread_val.value = op .default_val ;
128+ const uint16_t element_per_patch = X.size (p_id);
129+ KeyValuePair<HandleT, T> thread_val;
130+ thread_val.value = reduction_op .default_val () ;
148131 thread_val.key = HandleT (p_id, threadIdx .x );
149132 for (uint16_t i = threadIdx .x ; i < element_per_patch; i += blockSize) {
150133
151134 if (X.get_patch_info (p_id).is_owned (LocalT (i)) &&
152135 !X.get_patch_info (p_id).is_deleted (LocalT (i))) {
153136
154- if (attribute_id != INVALID32 )
155- {
156- HandleT handle (p_id, i);
157- cub::KeyValuePair<HandleT, T> current_pair (handle, X (p_id, i, attribute_id));
158- thread_val = op (thread_val, current_pair);
159- }
160- else {
161- for (uint32_t j = 0 ; j < num_attributes; ++j)
162- {
163- HandleT handle (p_id, i);
164- cub::KeyValuePair<HandleT, T> current_pair (handle, X (p_id, i, j));
165- thread_val = op (thread_val, current_pair);
137+ if (attribute_id != INVALID32) {
138+ HandleT handle (p_id, i);
139+ KeyValuePair<HandleT, T> current_pair (
140+ handle, X (p_id, i, attribute_id));
141+ thread_val = reduction_op (thread_val, current_pair);
142+ } else {
143+ for (uint32_t j = 0 ; j < num_attributes; ++j) {
144+ HandleT handle (p_id, i);
145+ KeyValuePair<HandleT, T> current_pair (handle,
146+ X (p_id, i, j));
147+ thread_val = reduction_op (thread_val, current_pair);
166148 }
167149 }
168150 }
169151 }
170- typedef cub::BlockReduce<cub::KeyValuePair<HandleT, T>, blockSize> BlockReduce;
171- __shared__ typename BlockReduce::TempStorage temp_storage;
172- cub::KeyValuePair<HandleT, T> block_aggregate = BlockReduce (temp_storage).Reduce (thread_val, op);
173- if (threadIdx .x == 0 )
174- {
175- d_block_output[blockIdx .x ] = block_aggregate;
176- }
152+
153+ cub_block_reduce<blockSize>(thread_val, d_block_output, reduction_op);
177154 }
178155}
179156
180157
181-
182158template <class T , uint32_t blockSize, typename ReductionOp, typename HandleT>
183159__launch_bounds__ (blockSize) __global__
184160 void generic_reduce (const Attribute<T, HandleT> X,
@@ -209,14 +185,8 @@ __launch_bounds__(blockSize) __global__
209185 }
210186 }
211187 }
212- typedef cub::BlockReduce<T, blockSize> BlockReduce;
213- __shared__ typename BlockReduce::TempStorage temp_storage;
214188
215- T block_aggregate =
216- BlockReduce (temp_storage).Reduce (thread_val, reduction_op);
217- if (threadIdx .x == 0 ) {
218- d_block_output[blockIdx .x ] = block_aggregate;
219- }
189+ cub_block_reduce<blockSize>(thread_val, d_block_output, reduction_op);
220190 }
221191}
222192
0 commit comments