1212
1313#include " common/cuda_hip/base/math.hpp"
1414#include " common/cuda_hip/base/runtime.hpp"
15+ #include " common/cuda_hip/components/cooperative_groups.hpp"
1516#include " common/cuda_hip/components/intrinsics.hpp"
1617#include " common/cuda_hip/components/memory.hpp"
1718#include " common/cuda_hip/components/merging.hpp"
@@ -75,8 +76,8 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_nnz(
7576 IndexType llh_col, IndexType out_nz, bool valid) {
7677 auto col = min (a_col, llh_col);
7778 // count the number of unique elements being merged
78- count +=
79- popcnt ( subwarp. ballot ( col <= row && a_col != llh_col && valid));
79+ count += popcnt ( group::ballot (
80+ subwarp, col <= row && a_col != llh_col && valid));
8081 return true ;
8182 });
8283 if (subwarp.thread_rank () == 0 ) {
@@ -149,7 +150,8 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init(
149150 auto llh_cur_val = subwarp.shfl (llh_val, merge_result.b_idx );
150151 auto valid = out_begin + lane < out_size;
151152 // check if the previous thread has matching columns
152- auto equal_mask = subwarp.ballot (a_cur_col == llh_cur_col && valid);
153+ auto equal_mask =
154+ group::ballot (subwarp, a_cur_col == llh_cur_col && valid);
153155 auto prev_equal_mask = equal_mask << 1 | skip_first;
154156 skip_first = bool (equal_mask >> (subwarp_size - 1 ));
155157 auto prev_equal = bool (prev_equal_mask & lanemask_eq);
@@ -179,7 +181,7 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init(
179181 // determine which threads will write output to L
180182 auto use_l = l_cur_col == r_col;
181183 auto do_write = !prev_equal && valid && r_col <= row;
182- auto l_new_advance_mask = subwarp. ballot (do_write);
184+ auto l_new_advance_mask = group:: ballot (subwarp, do_write);
183185 // store values
184186 if (do_write) {
185187 auto diag = l_vals[l_row_ptrs[r_col + 1 ] - 1 ];
@@ -192,7 +194,7 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init(
192194 // advance *_begin offsets
193195 auto a_advance = merge_result.a_advance ;
194196 auto llh_advance = merge_result.b_advance ;
195- auto l_advance = popcnt (subwarp. ballot (do_write && use_l));
197+ auto l_advance = popcnt (group:: ballot (subwarp, do_write && use_l));
196198 auto l_new_advance = popcnt (l_new_advance_mask);
197199 a_begin += a_advance;
198200 llh_begin += llh_advance;
@@ -295,7 +297,7 @@ __global__ __launch_bounds__(default_block_size) void ict_sweep(
295297 conj (load_relaxed (l_vals + (lh_idx + lh_col_begin)));
296298 }
297299 // remember the transposed element
298- auto found_transp = subwarp. ballot (lh_row == row);
300+ auto found_transp = group:: ballot (subwarp, lh_row == row);
299301 if (found_transp) {
300302 lh_nz =
301303 subwarp.shfl (lh_idx + lh_col_begin, ffs (found_transp) - 1 );
0 commit comments