Skip to content

Commit 8ecff84

Browse files
committed
use group::ballot not group.ballot in code
1 parent c043872 commit 8ecff84

23 files changed

Lines changed: 129 additions & 111 deletions

common/cuda_hip/components/bitvector.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ __global__ __launch_bounds__(default_block_size) void from_predicate(
4949
group::tiled_partition<block_size>(group::this_thread_block());
5050
const auto i = static_cast<IndexType>(subwarp_base + subwarp.thread_rank());
5151
const auto bit = i < size ? predicate(i) : false;
52-
const auto mask = subwarp.ballot(bit);
52+
const auto mask = group::ballot(subwarp, bit);
5353
if (subwarp.thread_rank() == 0) {
5454
bits[subwarp_id] = mask;
5555
popcounts[subwarp_id] = gko::detail::popcount(mask);

common/cuda_hip/components/merging.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -7,6 +7,7 @@
77

88

99
#include "common/cuda_hip/base/math.hpp"
10+
#include "common/cuda_hip/components/cooperative_groups.hpp"
1011
#include "common/cuda_hip/components/intrinsics.hpp"
1112
#include "common/cuda_hip/components/searching.hpp"
1213
#include "core/base/utils.hpp"
@@ -91,7 +92,7 @@ __forceinline__ __device__ detail::merge_result<ValueType> group_merge_step(
9192
auto a_val = group.shfl(a, a_idx);
9293
auto b_val = group.shfl(b, b_idx);
9394
auto cmp = a_val < b_val;
94-
auto a_advance = popcnt(group.ballot(cmp));
95+
auto a_advance = popcnt(group::ballot(group, cmp));
9596
auto b_advance = int(group.size()) - a_advance;
9697

9798
return {a_val, b_val, a_idx, b_idx, a_advance, b_advance};
@@ -208,7 +209,7 @@ __forceinline__ __device__ void group_match(const ValueType* __restrict__ a,
208209
a, a_size, b, b_size, group,
209210
[&](IndexType a_idx, ValueType a_val, IndexType b_idx, ValueType b_val,
210211
IndexType, bool valid) {
211-
auto matchmask = group.ballot(a_val == b_val && valid);
212+
auto matchmask = group::ballot(group, a_val == b_val && valid);
212213
match_fn(a_val, a_idx, b_idx, matchmask, a_val == b_val && valid);
213214
return a_idx < a_size && b_idx < b_size;
214215
});

common/cuda_hip/components/searching.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -7,6 +7,7 @@
77

88

99
#include "common/cuda_hip/base/config.hpp"
10+
#include "common/cuda_hip/components/cooperative_groups.hpp"
1011
#include "common/cuda_hip/components/intrinsics.hpp"
1112

1213

@@ -168,7 +169,7 @@ __forceinline__ __device__ IndexType group_wide_search(IndexType offset,
168169
*/
169170
auto base_idx = (group_pos - 1) * group.size() + 1;
170171
auto idx = base_idx + group.thread_rank();
171-
auto pos = ffs(group.ballot(idx >= length || p(offset + idx))) - 1;
172+
auto pos = ffs(group::ballot(group, idx >= length || p(offset + idx))) - 1;
172173
return offset + base_idx + pos;
173174
}
174175

@@ -205,7 +206,7 @@ __forceinline__ __device__ IndexType group_ary_search(IndexType offset,
205206
while (length > group.size()) {
206207
auto stride = length / group.size();
207208
auto idx = offset + group.thread_rank() * stride;
208-
auto mask = group.ballot(p(idx));
209+
auto mask = group::ballot(group, p(idx));
209210
// if the mask is 0, the partition point is in the last block
210211
// if the mask is ~0, the partition point is in the first block
211212
// otherwise, we go to the last block that returned a 0.
@@ -217,7 +218,7 @@ __forceinline__ __device__ IndexType group_ary_search(IndexType offset,
217218
auto idx = offset + group.thread_rank();
218219
// if the mask is 0, the partition point is at the end
219220
// otherwise it is the first set bit
220-
auto mask = group.ballot(idx >= end || p(idx));
221+
auto mask = group::ballot(group, idx >= end || p(idx));
221222
auto pos = mask == 0 ? group.size() : ffs(mask) - 1;
222223
return offset + pos;
223224
}

common/cuda_hip/factorization/cholesky_kernels.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ __global__ __launch_bounds__(default_block_size) void symbolic_factorize(
143143
const auto next_node =
144144
nz < lower_end - 1 ? postorder_cols[nz + 1] : diag_postorder;
145145
bool pred = node < next_node;
146-
auto mask = subwarp.ballot(pred);
146+
auto mask = group::ballot(subwarp, pred);
147147
while (mask) {
148148
if (pred) {
149149
const auto out_nz = out_base + popcnt(mask & prefix_mask);
@@ -152,7 +152,7 @@ __global__ __launch_bounds__(default_block_size) void symbolic_factorize(
152152
pred = node < next_node;
153153
}
154154
out_base += popcnt(mask);
155-
mask = subwarp.ballot(pred);
155+
mask = group::ballot(subwarp, pred);
156156
}
157157
}
158158
// add diagonal entry

common/cuda_hip/factorization/factorization_kernels.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ __launch_bounds__(default_block_size) void add_missing_diagonal_elements(
187187
thread_is_active ? old_col_idxs[old_idx] : IndexType{};
188188
// automatically false if thread is not active
189189
bool diagonal_add_required = !diagonal_added && row < col_idx;
190-
auto ballot = subwarp_grp.ballot(diagonal_add_required);
190+
auto ballot = group::ballot(subwarp_grp, diagonal_add_required);
191191

192192
if (ballot) {
193193
auto first_subwarp_idx = ffs(ballot) - 1;

common/cuda_hip/factorization/par_ict_kernels.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "common/cuda_hip/base/math.hpp"
1414
#include "common/cuda_hip/base/runtime.hpp"
15+
#include "common/cuda_hip/components/cooperative_groups.hpp"
1516
#include "common/cuda_hip/components/intrinsics.hpp"
1617
#include "common/cuda_hip/components/memory.hpp"
1718
#include "common/cuda_hip/components/merging.hpp"
@@ -75,8 +76,8 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_nnz(
7576
IndexType llh_col, IndexType out_nz, bool valid) {
7677
auto col = min(a_col, llh_col);
7778
// count the number of unique elements being merged
78-
count +=
79-
popcnt(subwarp.ballot(col <= row && a_col != llh_col && valid));
79+
count += popcnt(group::ballot(
80+
subwarp, col <= row && a_col != llh_col && valid));
8081
return true;
8182
});
8283
if (subwarp.thread_rank() == 0) {
@@ -149,7 +150,8 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init(
149150
auto llh_cur_val = subwarp.shfl(llh_val, merge_result.b_idx);
150151
auto valid = out_begin + lane < out_size;
151152
// check if the previous thread has matching columns
152-
auto equal_mask = subwarp.ballot(a_cur_col == llh_cur_col && valid);
153+
auto equal_mask =
154+
group::ballot(subwarp, a_cur_col == llh_cur_col && valid);
153155
auto prev_equal_mask = equal_mask << 1 | skip_first;
154156
skip_first = bool(equal_mask >> (subwarp_size - 1));
155157
auto prev_equal = bool(prev_equal_mask & lanemask_eq);
@@ -179,7 +181,7 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init(
179181
// determine which threads will write output to L
180182
auto use_l = l_cur_col == r_col;
181183
auto do_write = !prev_equal && valid && r_col <= row;
182-
auto l_new_advance_mask = subwarp.ballot(do_write);
184+
auto l_new_advance_mask = group::ballot(subwarp, do_write);
183185
// store values
184186
if (do_write) {
185187
auto diag = l_vals[l_row_ptrs[r_col + 1] - 1];
@@ -192,7 +194,7 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init(
192194
// advance *_begin offsets
193195
auto a_advance = merge_result.a_advance;
194196
auto llh_advance = merge_result.b_advance;
195-
auto l_advance = popcnt(subwarp.ballot(do_write && use_l));
197+
auto l_advance = popcnt(group::ballot(subwarp, do_write && use_l));
196198
auto l_new_advance = popcnt(l_new_advance_mask);
197199
a_begin += a_advance;
198200
llh_begin += llh_advance;
@@ -295,7 +297,7 @@ __global__ __launch_bounds__(default_block_size) void ict_sweep(
295297
conj(load_relaxed(l_vals + (lh_idx + lh_col_begin)));
296298
}
297299
// remember the transposed element
298-
auto found_transp = subwarp.ballot(lh_row == row);
300+
auto found_transp = group::ballot(subwarp, lh_row == row);
299301
if (found_transp) {
300302
lh_nz =
301303
subwarp.shfl(lh_idx + lh_col_begin, ffs(found_transp) - 1);

common/cuda_hip/factorization/par_ilut_filter_kernels.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -50,7 +50,7 @@ __device__ void abstract_filter_impl(const IndexType* row_ptrs,
5050
for (IndexType step = 0; step < num_steps; ++step) {
5151
auto idx = begin + lane + step * subwarp_size;
5252
auto keep = idx < end && pred(idx, begin, end);
53-
auto mask = subwarp.ballot(keep);
53+
auto mask = group::ballot(subwarp, keep);
5454
step_cb(row, idx, keep, popcnt(mask), popcnt(mask & lane_prefix_mask));
5555
}
5656
finish_cb(row, lane);

common/cuda_hip/factorization/par_ilut_spgeam_kernels.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -74,10 +74,10 @@ __global__ __launch_bounds__(default_block_size) void tri_spgeam_nnz(
7474
IndexType out_nz, bool valid) {
7575
auto col = min(a_col, lu_col);
7676
// count the number of unique elements being merged
77-
l_count +=
78-
popcnt(subwarp.ballot(col <= row && a_col != lu_col && valid));
79-
u_count +=
80-
popcnt(subwarp.ballot(col >= row && a_col != lu_col && valid));
77+
l_count += popcnt(
78+
group::ballot(subwarp, col <= row && a_col != lu_col && valid));
79+
u_count += popcnt(
80+
group::ballot(subwarp, col >= row && a_col != lu_col && valid));
8181
return true;
8282
});
8383
if (subwarp.thread_rank() == 0) {
@@ -172,7 +172,8 @@ __global__ __launch_bounds__(default_block_size) void tri_spgeam_init(
172172
auto lu_cur_val = subwarp.shfl(lu_val, merge_result.b_idx);
173173
auto valid = out_begin + lane < out_size;
174174
// check if the previous thread has matching columns
175-
auto equal_mask = subwarp.ballot(a_cur_col == lu_cur_col && valid);
175+
auto equal_mask =
176+
group::ballot(subwarp, a_cur_col == lu_cur_col && valid);
176177
auto prev_equal_mask = equal_mask << 1 | skip_first;
177178
skip_first = bool(equal_mask >> (subwarp_size - 1));
178179
auto prev_equal = bool(prev_equal_mask & lanemask_eq);
@@ -197,9 +198,9 @@ __global__ __launch_bounds__(default_block_size) void tri_spgeam_init(
197198
// determine which threads will write output to L or U
198199
auto use_lpu = lpu_cur_col == r_col;
199200
auto l_new_advance_mask =
200-
subwarp.ballot(r_col <= row && !prev_equal && valid);
201+
group::ballot(subwarp, r_col <= row && !prev_equal && valid);
201202
auto u_new_advance_mask =
202-
subwarp.ballot(r_col >= row && !prev_equal && valid);
203+
group::ballot(subwarp, r_col >= row && !prev_equal && valid);
203204
// store values
204205
if (!prev_equal && valid) {
205206
auto diag =
@@ -222,7 +223,7 @@ __global__ __launch_bounds__(default_block_size) void tri_spgeam_init(
222223
auto a_advance = merge_result.a_advance;
223224
auto lu_advance = merge_result.b_advance;
224225
auto lpu_advance =
225-
popcnt(subwarp.ballot(use_lpu && !prev_equal && valid));
226+
popcnt(group::ballot(subwarp, use_lpu && !prev_equal && valid));
226227
auto l_new_advance = popcnt(l_new_advance_mask);
227228
auto u_new_advance = popcnt(u_new_advance_mask);
228229
a_begin += a_advance;

common/cuda_hip/factorization/par_ilut_sweep_kernels.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "common/cuda_hip/base/math.hpp"
1212
#include "common/cuda_hip/base/runtime.hpp"
13+
#include "common/cuda_hip/components/cooperative_groups.hpp"
1314
#include "common/cuda_hip/components/intrinsics.hpp"
1415
#include "common/cuda_hip/components/memory.hpp"
1516
#include "common/cuda_hip/components/merging.hpp"
@@ -105,7 +106,7 @@ __global__ __launch_bounds__(default_block_size) void sweep(
105106
load_relaxed(ut_vals + (ut_idx + ut_col_begin));
106107
}
107108
// remember the transposed element
108-
auto found_transp = subwarp.ballot(ut_row == row);
109+
auto found_transp = group::ballot(subwarp, ut_row == row);
109110
if (found_transp) {
110111
ut_nz =
111112
subwarp.shfl(ut_idx + ut_col_begin, ffs(found_transp) - 1);

common/cuda_hip/matrix/csr_kernels.template.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ __global__ __launch_bounds__(default_block_size) void spgeam_nnz(
609609
a_col_idxs + a_begin, a_size, b_col_idxs + b_begin, b_size, subwarp,
610610
[&](IndexType, IndexType a_col, IndexType, IndexType b_col, IndexType,
611611
bool valid) {
612-
count += popcnt(subwarp.ballot(a_col != b_col && valid));
612+
count += popcnt(group::ballot(subwarp, a_col != b_col && valid));
613613
return true;
614614
});
615615

@@ -657,7 +657,7 @@ __global__ __launch_bounds__(default_block_size) void spgeam(
657657
[&](IndexType a_nz, IndexType a_col, IndexType b_nz, IndexType b_col,
658658
IndexType, bool valid) {
659659
auto c_col = min(a_col, b_col);
660-
auto equal_mask = subwarp.ballot(a_col == b_col && valid);
660+
auto equal_mask = group::ballot(subwarp, a_col == b_col && valid);
661661
// check if the elements in the previous merge step are
662662
// equal
663663
auto prev_equal_mask = equal_mask << 1 | skip_first;

0 commit comments

Comments
 (0)