Skip to content

Commit a0e9b0c

Browse files
committed
update tests and fix issues
1 parent f69f1d8 commit a0e9b0c

20 files changed

Lines changed: 626 additions & 332 deletions

.github/workflows/publish.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ jobs:
147147
LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
148148
PATH=/usr/local/cuda/bin:$PATH
149149
CIBW_BEFORE_BUILD: >
150+
rm -f build/.librmm_dir &&
151+
mkdir -p build &&
150152
python -m pip install -U pip
151153
scikit-build-core cmake ninja nanobind
152154
librmm-cu${{ matrix.cuda_major }} &&
@@ -157,8 +159,8 @@ jobs:
157159
ln -sf "$RMM_ROOT/lib64/librmm.so" /usr/local/lib/librmm.so &&
158160
ln -sf "$LOG_ROOT/lib64/librapids_logger.so" /usr/local/lib/librapids_logger.so &&
159161
ldconfig &&
160-
python -c "import librmm; print(librmm.__path__[0])" > /tmp/.librmm_dir &&
161-
echo "[rsc-build] marker=$(cat /tmp/.librmm_dir)"
162+
python -c "import librmm; print(librmm.__path__[0])" > build/.librmm_dir &&
163+
echo "[rsc-build] marker=$(cat build/.librmm_dir)"
162164
CIBW_TEST_SKIP: "*"
163165
CIBW_TEST_COMMAND: ""
164166
CIBW_REPAIR_WHEEL_COMMAND: "auditwheel repair --exclude libcublas.so.${{ matrix.cuda_major }} --exclude libcublasLt.so.${{ matrix.cuda_major }} --exclude libcudart.so.${{ matrix.cuda_major }} --exclude librmm.so --exclude librapids_logger.so -w {dest_dir} {wheel}"

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ CLAUDE.md
5151

5252
# tmp_scripts
5353
tmp_scripts/
54-
benchmarks/
54+
/benchmarks/

CMakeLists.txt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,19 @@ if (RSC_BUILD_EXTENSIONS)
5050
if (RSC_PYTHON_RMM_DIR AND EXISTS "${RSC_PYTHON_RMM_DIR}/rmm-config.cmake")
5151
list(APPEND RSC_RMM_HINTS "${RSC_PYTHON_RMM_DIR}")
5252
endif()
53-
if(EXISTS "/tmp/.librmm_dir")
54-
file(READ "/tmp/.librmm_dir" _rsc_librmm_marker)
53+
# Wheel builds install librmm/rapids_logger into the isolated build env and
54+
# write build/.librmm_dir from CIBW_BEFORE_BUILD. publish.yml also symlinks
55+
# those shared libraries into /usr/local/lib so auditwheel can see and exclude
56+
# them instead of bundling RAPIDS runtime libraries into the wheel.
57+
if(DEFINED ENV{RSC_LIBRMM_DIR} AND EXISTS "$ENV{RSC_LIBRMM_DIR}/lib64/cmake/rmm/rmm-config.cmake")
58+
set(_rsc_librmm_marker "$ENV{RSC_LIBRMM_DIR}")
59+
elseif(EXISTS "${CMAKE_SOURCE_DIR}/build/.librmm_dir")
60+
file(READ "${CMAKE_SOURCE_DIR}/build/.librmm_dir" _rsc_librmm_marker)
5561
string(STRIP "${_rsc_librmm_marker}" _rsc_librmm_marker)
62+
else()
63+
set(_rsc_librmm_marker "")
64+
endif()
65+
if(NOT "${_rsc_librmm_marker}" STREQUAL "" AND EXISTS "${_rsc_librmm_marker}/lib64/cmake/rmm/rmm-config.cmake")
5666
file(GLOB _rsc_marker_rmm_dirs "${_rsc_librmm_marker}/lib64/cmake/rmm")
5767
file(GLOB _rsc_marker_rapids_prefixes
5868
"${_rsc_librmm_marker}/lib64"

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ requires = [
44
"nanobind>=2.0.0",
55
"setuptools-scm>=8",
66
# librmm headers/CMake config are needed at build time for Wilcoxon.
7-
# CUDA wheel builds rewrite this to the matching cu12/cu13 package.
7+
# Generic isolated source builds default to CUDA 12. CUDA wheel builds
8+
# rewrite this to the matching cu12/cu13 package; CUDA 13 source builds
9+
# should build in an existing RAPIDS env with --no-build-isolation.
810
"librmm-cu12>=25.10",
911
]
1012
build-backend = "scikit_build_core.build"

src/rapids_singlecell/_cuda/wilcoxon/kernels_wilcoxon_ovo.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <cuda_runtime.h>
44

5+
#include "wilcoxon_fast_common.cuh"
6+
57
// ============================================================================
68
// Warp reduction helper (sum doubles across block via warp_buf)
79
// ============================================================================

src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,14 @@ static void launch_ovr_rank_dense_streaming(
4141
}
4242

4343
size_t sub_items = (size_t)n_rows * sub_batch_cols;
44-
if (sub_items > (size_t)std::numeric_limits<int>::max()) {
45-
throw std::runtime_error(
46-
"Dense OVR sub-batch exceeds CUB int item limit");
47-
}
44+
int sub_items_i32 = checked_cub_items(sub_items, "Dense OVR sub-batch");
4845

4946
size_t cub_temp_bytes = 0;
5047
{
5148
auto* fk = reinterpret_cast<float*>(1);
5249
auto* iv = reinterpret_cast<int*>(1);
5350
cub::DeviceSegmentedRadixSort::SortPairs(
54-
nullptr, cub_temp_bytes, fk, fk, iv, iv, (int)sub_items,
51+
nullptr, cub_temp_bytes, fk, fk, iv, iv, sub_items_i32,
5552
sub_batch_cols, iv, iv + 1, BEGIN_BIT, END_BIT);
5653
}
5754

@@ -97,7 +94,8 @@ static void launch_ovr_rank_dense_streaming(
9794
int batch_idx = 0;
9895
while (col < n_cols) {
9996
int sb_cols = std::min(sub_batch_cols, n_cols - col);
100-
int sb_items = n_rows * sb_cols;
97+
int sb_items = checked_int_product((size_t)n_rows, (size_t)sb_cols,
98+
"Dense OVR active sub-batch");
10199
int s = batch_idx % n_streams;
102100
cudaStream_t stream = streams[s];
103101
auto& buf = bufs[s];
@@ -184,32 +182,30 @@ static void launch_ovo_rank_dense_tiered_impl(
184182
n_streams = (n_cols + sub_batch_cols - 1) / sub_batch_cols;
185183

186184
size_t sub_ref_items = (size_t)n_ref * sub_batch_cols;
187-
if (sub_ref_items > (size_t)std::numeric_limits<int>::max()) {
188-
throw std::runtime_error(
189-
"Dense OVO reference sub-batch exceeds CUB int item limit");
190-
}
185+
int sub_ref_items_i32 =
186+
checked_cub_items(sub_ref_items, "Dense OVO reference sub-batch");
191187

192188
size_t sub_grp_items = (size_t)n_all_grp * sub_batch_cols;
193-
if (sub_grp_items > (size_t)std::numeric_limits<int>::max()) {
194-
throw std::runtime_error(
195-
"Dense OVO sub-batch exceeds CUB int item limit");
196-
}
189+
int sub_grp_items_i32 =
190+
checked_cub_items(sub_grp_items, "Dense OVO group sub-batch");
197191

198192
size_t grp_cub_temp_bytes = 0;
199193
if (needs_tier3) {
200-
int max_grp_seg = n_sort_groups * sub_batch_cols;
194+
int max_grp_seg =
195+
checked_int_product((size_t)n_sort_groups, (size_t)sub_batch_cols,
196+
"Dense OVO group segment count");
201197
auto* fk = reinterpret_cast<float*>(1);
202198
auto* doff = reinterpret_cast<int*>(1);
203199
cub::DeviceSegmentedRadixSort::SortKeys(
204-
nullptr, grp_cub_temp_bytes, fk, fk, (int)sub_grp_items,
205-
max_grp_seg, doff, doff + 1, BEGIN_BIT, END_BIT);
200+
nullptr, grp_cub_temp_bytes, fk, fk, sub_grp_items_i32, max_grp_seg,
201+
doff, doff + 1, BEGIN_BIT, END_BIT);
206202
}
207203
size_t ref_cub_temp_bytes = 0;
208204
if (!ref_is_sorted) {
209205
auto* fk = reinterpret_cast<float*>(1);
210206
auto* doff = reinterpret_cast<int*>(1);
211207
cub::DeviceSegmentedRadixSort::SortKeys(
212-
nullptr, ref_cub_temp_bytes, fk, fk, (int)sub_ref_items,
208+
nullptr, ref_cub_temp_bytes, fk, fk, sub_ref_items_i32,
213209
sub_batch_cols, doff, doff + 1, BEGIN_BIT, END_BIT);
214210
}
215211

@@ -270,7 +266,9 @@ static void launch_ovo_rank_dense_tiered_impl(
270266
pool.alloc<double>((size_t)n_groups * sub_batch_cols);
271267
if (needs_tier3) {
272268
bufs[s].grp_sorted = pool.alloc<float>(sub_grp_items);
273-
int max_seg = n_sort_groups * sub_batch_cols;
269+
int max_seg = checked_int_product((size_t)n_sort_groups,
270+
(size_t)sub_batch_cols,
271+
"Dense OVO group segment buffer");
274272
bufs[s].grp_seg_offsets = pool.alloc<int>(max_seg);
275273
bufs[s].grp_seg_ends = pool.alloc<int>(max_seg);
276274
} else {
@@ -287,8 +285,12 @@ static void launch_ovo_rank_dense_tiered_impl(
287285
int batch_idx = 0;
288286
while (col < n_cols) {
289287
int sb_cols = std::min(sub_batch_cols, n_cols - col);
290-
int sb_ref_items_actual = n_ref * sb_cols;
291-
int sb_grp_items_actual = n_all_grp * sb_cols;
288+
int sb_ref_items_actual =
289+
checked_int_product((size_t)n_ref, (size_t)sb_cols,
290+
"Dense OVO active reference sub-batch");
291+
int sb_grp_items_actual =
292+
checked_int_product((size_t)n_all_grp, (size_t)sb_cols,
293+
"Dense OVO active group sub-batch");
292294
int s = batch_idx % n_streams;
293295
cudaStream_t stream = streams[s];
294296
auto& buf = bufs[s];
@@ -343,7 +345,9 @@ static void launch_ovo_rank_dense_tiered_impl(
343345
compute_tie_corr, padded_grp_size, upper_skip_le);
344346
CUDA_CHECK_LAST_ERROR(ovo_fused_sort_rank_kernel);
345347
} else if (needs_tier3) {
346-
int sb_grp_seg = n_sort_groups * sb_cols;
348+
int sb_grp_seg =
349+
checked_int_product((size_t)n_sort_groups, (size_t)sb_cols,
350+
"Dense OVO active group segment count");
347351
int blk = (sb_grp_seg + UTIL_BLOCK_SIZE - 1) / UTIL_BLOCK_SIZE;
348352
build_tier3_seg_begin_end_offsets_kernel<<<blk, UTIL_BLOCK_SIZE, 0,
349353
stream>>>(

src/rapids_singlecell/_cuda/wilcoxon/wilcoxon_fast_common.cuh

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <cstdint>
4+
#include <limits>
45
#include <stdexcept>
56
#include <string>
67
#include <vector>
@@ -48,6 +49,39 @@ constexpr int TIER1_GROUP_THRESHOLD = 2500;
4849
// 512 MB per stream dense slab + same for sorted copy ≈ 1 GB / stream.
4950
constexpr size_t GROUP_DENSE_BUDGET_ITEMS = 128 * 1024 * 1024;
5051

52+
// Query the maximum shared memory per block (in bytes, per the CUDA runtime
// attribute cudaDevAttrMaxSharedMemoryPerBlock) for the calling thread's
// current device.
//
// Returns: the attribute value widened to size_t.
// Throws:  std::runtime_error if either runtime query fails. Previously the
//          return codes were ignored, so a failure silently produced 0 and
//          left a pending runtime error that a later last-error check would
//          attribute to an unrelated kernel launch.
static inline size_t wilcoxon_max_smem_per_block() {
    int device = 0;
    cudaError_t err = cudaGetDevice(&device);
    if (err != cudaSuccess) {
        cudaGetLastError();  // reset the recorded error; it is reported below
        throw std::runtime_error(
            std::string("wilcoxon_max_smem_per_block: cudaGetDevice failed: ") +
            cudaGetErrorString(err));
    }

    int max_smem = 0;
    err = cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock,
                                 device);
    if (err != cudaSuccess) {
        cudaGetLastError();  // reset the recorded error; it is reported below
        throw std::runtime_error(
            std::string("wilcoxon_max_smem_per_block: "
                        "cudaDeviceGetAttribute failed: ") +
            cudaGetErrorString(err));
    }
    return (size_t)max_smem;
}
60+
61+
static inline int checked_cub_items(size_t count, const char* context) {
62+
if (count > (size_t)std::numeric_limits<int>::max()) {
63+
throw std::runtime_error(std::string(context) +
64+
" exceeds CUB int item limit");
65+
}
66+
return (int)count;
67+
}
68+
69+
// Narrow a span length to int, rejecting values that an int32 cannot hold.
//
// count:   span length, as a size_t.
// context: label prepended to the error message.
// Returns: `count` as int when it is <= INT_MAX.
// Throws:  std::runtime_error("<context> exceeds int32 offset limit")
//          otherwise.
static inline int checked_int_span(size_t count, const char* context) {
    const bool fits =
        !(count > static_cast<size_t>(std::numeric_limits<int>::max()));
    if (!fits) {
        const std::string what =
            std::string(context) + " exceeds int32 offset limit";
        throw std::runtime_error(what);
    }
    return static_cast<int>(count);
}
76+
77+
static inline int checked_int_product(size_t a, size_t b, const char* context) {
78+
if (a != 0 && b > (size_t)std::numeric_limits<int>::max() / a) {
79+
throw std::runtime_error(std::string(context) +
80+
" exceeds int32 item limit");
81+
}
82+
return (int)(a * b);
83+
}
84+
5185
// ---------------------------------------------------------------------------
5286
// RAII guard for cudaHostRegister. Unregisters on scope exit even when an
5387
// exception unwinds — prevents leaked host pinning on stream-sync failures.
@@ -60,9 +94,9 @@ struct HostRegisterGuard {
6094
if (p && bytes > 0) {
6195
cudaError_t err = cudaHostRegister(p, bytes, flags);
6296
if (err != cudaSuccess) {
63-
// Already-registered memory is fine; anything else means the
64-
// subsequent kernels would read garbage from an unmapped
65-
// pointer, so surface the error immediately.
97+
// Already-registered memory belongs to another owner; use it
98+
// without unregistering here. Other failures mean mapped reads
99+
// would be unsafe, so surface them immediately.
66100
if (err == cudaErrorHostMemoryAlreadyRegistered) {
67101
cudaGetLastError(); // clear sticky error flag
68102
} else {
@@ -116,6 +150,10 @@ struct RmmScratchPool {
116150
template <typename T>
117151
T* alloc(size_t count) {
118152
if (count == 0) count = 1;
153+
if (count > std::numeric_limits<size_t>::max() / sizeof(T)) {
154+
throw std::runtime_error(
155+
"Wilcoxon scratch allocation size overflow");
156+
}
119157
size_t bytes = count * sizeof(T);
120158
void* ptr = wilcoxon_rmm_allocate(bytes);
121159
bufs.push_back({ptr, bytes});

0 commit comments

Comments
 (0)