@@ -41,17 +41,14 @@ static void launch_ovr_rank_dense_streaming(
4141 }
4242
4343 size_t sub_items = (size_t )n_rows * sub_batch_cols;
44- if (sub_items > (size_t )std::numeric_limits<int >::max ()) {
45- throw std::runtime_error (
46- " Dense OVR sub-batch exceeds CUB int item limit" );
47- }
44+ int sub_items_i32 = checked_cub_items (sub_items, " Dense OVR sub-batch" );
4845
4946 size_t cub_temp_bytes = 0 ;
5047 {
5148 auto * fk = reinterpret_cast <float *>(1 );
5249 auto * iv = reinterpret_cast <int *>(1 );
5350 cub::DeviceSegmentedRadixSort::SortPairs (
54- nullptr , cub_temp_bytes, fk, fk, iv, iv, ( int )sub_items ,
51+ nullptr , cub_temp_bytes, fk, fk, iv, iv, sub_items_i32 ,
5552 sub_batch_cols, iv, iv + 1 , BEGIN_BIT, END_BIT);
5653 }
5754
@@ -97,7 +94,8 @@ static void launch_ovr_rank_dense_streaming(
9794 int batch_idx = 0 ;
9895 while (col < n_cols) {
9996 int sb_cols = std::min (sub_batch_cols, n_cols - col);
100- int sb_items = n_rows * sb_cols;
97+ int sb_items = checked_int_product ((size_t )n_rows, (size_t )sb_cols,
98+ " Dense OVR active sub-batch" );
10199 int s = batch_idx % n_streams;
102100 cudaStream_t stream = streams[s];
103101 auto & buf = bufs[s];
@@ -184,32 +182,30 @@ static void launch_ovo_rank_dense_tiered_impl(
184182 n_streams = (n_cols + sub_batch_cols - 1 ) / sub_batch_cols;
185183
186184 size_t sub_ref_items = (size_t )n_ref * sub_batch_cols;
187- if (sub_ref_items > (size_t )std::numeric_limits<int >::max ()) {
188- throw std::runtime_error (
189- " Dense OVO reference sub-batch exceeds CUB int item limit" );
190- }
185+ int sub_ref_items_i32 =
186+ checked_cub_items (sub_ref_items, " Dense OVO reference sub-batch" );
191187
192188 size_t sub_grp_items = (size_t )n_all_grp * sub_batch_cols;
193- if (sub_grp_items > (size_t )std::numeric_limits<int >::max ()) {
194- throw std::runtime_error (
195- " Dense OVO sub-batch exceeds CUB int item limit" );
196- }
189+ int sub_grp_items_i32 =
190+ checked_cub_items (sub_grp_items, " Dense OVO group sub-batch" );
197191
198192 size_t grp_cub_temp_bytes = 0 ;
199193 if (needs_tier3) {
200- int max_grp_seg = n_sort_groups * sub_batch_cols;
194+ int max_grp_seg =
195+ checked_int_product ((size_t )n_sort_groups, (size_t )sub_batch_cols,
196+ " Dense OVO group segment count" );
201197 auto * fk = reinterpret_cast <float *>(1 );
202198 auto * doff = reinterpret_cast <int *>(1 );
203199 cub::DeviceSegmentedRadixSort::SortKeys (
204- nullptr , grp_cub_temp_bytes, fk, fk, ( int )sub_grp_items ,
205- max_grp_seg, doff, doff + 1 , BEGIN_BIT, END_BIT);
200+ nullptr , grp_cub_temp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg ,
201+ doff, doff + 1 , BEGIN_BIT, END_BIT);
206202 }
207203 size_t ref_cub_temp_bytes = 0 ;
208204 if (!ref_is_sorted) {
209205 auto * fk = reinterpret_cast <float *>(1 );
210206 auto * doff = reinterpret_cast <int *>(1 );
211207 cub::DeviceSegmentedRadixSort::SortKeys (
212- nullptr , ref_cub_temp_bytes, fk, fk, ( int )sub_ref_items ,
208+ nullptr , ref_cub_temp_bytes, fk, fk, sub_ref_items_i32 ,
213209 sub_batch_cols, doff, doff + 1 , BEGIN_BIT, END_BIT);
214210 }
215211
@@ -270,7 +266,9 @@ static void launch_ovo_rank_dense_tiered_impl(
270266 pool.alloc <double >((size_t )n_groups * sub_batch_cols);
271267 if (needs_tier3) {
272268 bufs[s].grp_sorted = pool.alloc <float >(sub_grp_items);
273- int max_seg = n_sort_groups * sub_batch_cols;
269+ int max_seg = checked_int_product ((size_t )n_sort_groups,
270+ (size_t )sub_batch_cols,
271+ " Dense OVO group segment buffer" );
274272 bufs[s].grp_seg_offsets = pool.alloc <int >(max_seg);
275273 bufs[s].grp_seg_ends = pool.alloc <int >(max_seg);
276274 } else {
@@ -287,8 +285,12 @@ static void launch_ovo_rank_dense_tiered_impl(
287285 int batch_idx = 0 ;
288286 while (col < n_cols) {
289287 int sb_cols = std::min (sub_batch_cols, n_cols - col);
290- int sb_ref_items_actual = n_ref * sb_cols;
291- int sb_grp_items_actual = n_all_grp * sb_cols;
288+ int sb_ref_items_actual =
289+ checked_int_product ((size_t )n_ref, (size_t )sb_cols,
290+ " Dense OVO active reference sub-batch" );
291+ int sb_grp_items_actual =
292+ checked_int_product ((size_t )n_all_grp, (size_t )sb_cols,
293+ " Dense OVO active group sub-batch" );
292294 int s = batch_idx % n_streams;
293295 cudaStream_t stream = streams[s];
294296 auto & buf = bufs[s];
@@ -343,7 +345,9 @@ static void launch_ovo_rank_dense_tiered_impl(
343345 compute_tie_corr, padded_grp_size, upper_skip_le);
344346 CUDA_CHECK_LAST_ERROR (ovo_fused_sort_rank_kernel);
345347 } else if (needs_tier3) {
346- int sb_grp_seg = n_sort_groups * sb_cols;
348+ int sb_grp_seg =
349+ checked_int_product ((size_t )n_sort_groups, (size_t )sb_cols,
350+ " Dense OVO active group segment count" );
347351 int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1 ) / UTIL_BLOCK_SIZE;
348352 build_tier3_seg_begin_end_offsets_kernel<<<blk, UTIL_BLOCK_SIZE, 0 ,
349353 stream>>> (
0 commit comments