Skip to content

Commit 4df6b7a

Browse files
committed
Initial fix
1 parent 2afa3c6 commit 4df6b7a

5 files changed

Lines changed: 220 additions & 19 deletions

File tree

src/rapids_singlecell/_cuda/bbknn/bbknn.cu

Lines changed: 46 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,20 +6,48 @@
66
using namespace nb::literals;
77

88
constexpr int BLOCK_SIZE = 64;
9+
// Block-cooperative sort kernel: BLOCK_THREADS * ITEMS_PER_THREAD = 1024.
10+
// Rows larger than this must use the per-thread kernel (kernel 1).
11+
constexpr int SORT_BLOCK_THREADS = 128;
12+
constexpr int SORT_ITEMS_PER_THREAD = 8;
13+
constexpr int SORT_TILE_SIZE = SORT_BLOCK_THREADS * SORT_ITEMS_PER_THREAD;
914

1015
static inline void launch_find_top_k_per_row(const float* data,
1116
const int* indptr, int n_rows,
1217
int trim, float* vals,
1318
cudaStream_t stream) {
14-
dim3 block(BLOCK_SIZE);
15-
dim3 grid((n_rows + BLOCK_SIZE - 1) / BLOCK_SIZE);
16-
size_t shared_mem_size = static_cast<size_t>(BLOCK_SIZE) *
17-
static_cast<size_t>(trim) * sizeof(float);
19+
// Each thread keeps its row's top-`trim` values in shared memory, so the
20+
// per-block shared-mem request is BLOCK_SIZE * trim * sizeof(float).
21+
// The default per-block shared-mem cap is ~48 KB; halve the block size
22+
// until the request fits so the launch succeeds for any reasonable trim.
23+
constexpr size_t SHARED_MEM_BUDGET = 48 * 1024;
24+
const size_t per_thread_bytes = static_cast<size_t>(trim) * sizeof(float);
25+
int block_size = BLOCK_SIZE;
26+
while (block_size > 1 &&
27+
static_cast<size_t>(block_size) * per_thread_bytes >
28+
SHARED_MEM_BUDGET) {
29+
block_size /= 2;
30+
}
31+
dim3 block(block_size);
32+
dim3 grid((n_rows + block_size - 1) / block_size);
33+
size_t shared_mem_size = static_cast<size_t>(block_size) * per_thread_bytes;
1834
find_top_k_per_row_kernel<<<grid, block, shared_mem_size, stream>>>(
1935
data, indptr, n_rows, trim, vals);
2036
CUDA_CHECK_LAST_ERROR(find_top_k_per_row_kernel);
2137
}
2238

39+
static inline void launch_find_top_k_per_row_sorted(const float* data,
40+
const int* indptr,
41+
int n_rows, int trim,
42+
float* vals,
43+
cudaStream_t stream) {
44+
dim3 block(SORT_BLOCK_THREADS);
45+
dim3 grid(n_rows);
46+
find_top_k_per_row_sorted_kernel<SORT_BLOCK_THREADS, SORT_ITEMS_PER_THREAD>
47+
<<<grid, block, 0, stream>>>(data, indptr, n_rows, trim, vals);
48+
CUDA_CHECK_LAST_ERROR(find_top_k_per_row_sorted_kernel);
49+
}
50+
2351
static inline void launch_cut_smaller(int* indptr, int* index, float* data,
2452
float* vals, int n_rows,
2553
cudaStream_t stream) {
@@ -43,6 +71,20 @@ void register_bindings(nb::module_& m) {
4371
"data"_a, "indptr"_a, nb::kw_only(), "n_rows"_a, "trim"_a, "vals"_a,
4472
"stream"_a = 0);
4573

74+
m.def(
75+
"find_top_k_per_row_sorted",
76+
[](gpu_array_c<const float, Device> data,
77+
gpu_array_c<const int, Device> indptr, int n_rows, int trim,
78+
gpu_array_c<float, Device> vals, std::uintptr_t stream) {
79+
launch_find_top_k_per_row_sorted(data.data(), indptr.data(), n_rows,
80+
trim, vals.data(),
81+
(cudaStream_t)stream);
82+
},
83+
"data"_a, "indptr"_a, nb::kw_only(), "n_rows"_a, "trim"_a, "vals"_a,
84+
"stream"_a = 0);
85+
86+
m.def("sort_tile_size", []() { return SORT_TILE_SIZE; });
87+
4688
m.def(
4789
"cut_smaller",
4890
[](gpu_array_c<int, Device> indptr, gpu_array_c<int, Device> index,

src/rapids_singlecell/_cuda/bbknn/kernels_bbknn.cuh

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3+
#include <cub/block/block_radix_sort.cuh>
34
#include <cuda_runtime.h>
5+
#include <math_constants.h>
46

57
__global__ void find_top_k_per_row_kernel(const float* __restrict__ data,
68
const int* __restrict__ indptr,
@@ -49,6 +51,59 @@ __global__ void find_top_k_per_row_kernel(const float* __restrict__ data,
4951
vals[row] = top_k[min_index];
5052
}
5153

54+
// Block-cooperative variant: one CUDA block per row, sorts the row with
55+
// BlockRadixSort, returns the `trim`-th largest as the cut value. Shared
56+
// memory is the CUB sort temp storage only, independent of `trim`, so it
57+
// scales to large `trim` values where the per-thread top-k kernel runs out
58+
// of shared memory. Requires every row to fit in BLOCK_THREADS *
59+
// ITEMS_PER_THREAD.
60+
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
61+
__global__ void find_top_k_per_row_sorted_kernel(const float* __restrict__ data,
62+
const int* __restrict__ indptr,
63+
const int n_rows,
64+
const int trim,
65+
float* __restrict__ vals) {
66+
int row = blockIdx.x;
67+
if (row >= n_rows) {
68+
return;
69+
}
70+
71+
int start = indptr[row];
72+
int end = indptr[row + 1];
73+
int length = end - start;
74+
75+
if (length <= trim) {
76+
if (threadIdx.x == 0) {
77+
vals[row] = 0.0f; // insufficient elements
78+
}
79+
return;
80+
}
81+
82+
using BlockRadixSort =
83+
cub::BlockRadixSort<float, BLOCK_THREADS, ITEMS_PER_THREAD>;
84+
__shared__ typename BlockRadixSort::TempStorage temp_storage;
85+
86+
float thread_keys[ITEMS_PER_THREAD];
87+
#pragma unroll
88+
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
89+
int idx = threadIdx.x * ITEMS_PER_THREAD + i;
90+
// Pad out-of-range with -inf so they sort to the bottom of a
91+
// descending sort and never appear among the trim largest.
92+
thread_keys[i] = (idx < length) ? data[start + idx] : -CUDART_INF_F;
93+
}
94+
95+
BlockRadixSort(temp_storage).SortDescending(thread_keys);
96+
97+
// After SortDescending with blocked arrangement, sorted index i lives at
98+
// thread (i / ITEMS_PER_THREAD), local slot (i % ITEMS_PER_THREAD).
99+
int target_idx = trim - 1;
100+
int target_thread = target_idx / ITEMS_PER_THREAD;
101+
int target_item = target_idx % ITEMS_PER_THREAD;
102+
if (threadIdx.x == target_thread) {
103+
vals[row] = thread_keys[target_item];
104+
}
105+
}
106+
52107
__global__ void cut_smaller_kernel(const int* __restrict__ indptr,
53108
const int* __restrict__ index,
54109
float* __restrict__ data,

src/rapids_singlecell/preprocessing/_neighbors/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -407,8 +407,16 @@ def bbknn(
407407
knn_indices[:, col_range] = ind_to[sub_ind]
408408
knn_dist[:, col_range] = sub_dist
409409

410+
# Sort each row so neighbors are ordered closest-first across all batches.
411+
# fuzzy_simplicial_set uses the first non-zero distance per row as the
412+
# local-connectivity rho; unsorted input collapses sigma and weights.
413+
order = cp.argsort(knn_dist, axis=1)
414+
row_idx = cp.arange(n_obs)[:, None]
415+
knn_dist = knn_dist[row_idx, order]
416+
knn_indices = knn_indices[row_idx, order]
417+
410418
if trim is None:
411-
trim = 10 * neighbors_within_batch
419+
trim = 10 * total_neighbors
412420

413421
params = dict(
414422
n_neighbors=total_neighbors,

src/rapids_singlecell/preprocessing/_neighbors/_helper/__init__.py

Lines changed: 42 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -130,30 +130,61 @@ def _fix_self_distances(knn_dist: cp.ndarray, metric: _Metrics) -> cp.ndarray:
130130
return knn_dist
131131

132132

133-
def _trimming(cnts: cp_sparse.csr_matrix, trim: int) -> cp_sparse.csr_matrix:
133+
_TRIM_SORT_THRESHOLD = 100
134+
135+
136+
def _trimming(
137+
cnts: cp_sparse.csr_matrix,
138+
trim: int,
139+
*,
140+
kernel: str = "auto",
141+
) -> cp_sparse.csr_matrix:
134142
from rapids_singlecell._cuda._bbknn_cuda import (
135143
cut_smaller,
136144
find_top_k_per_row,
145+
find_top_k_per_row_sorted,
146+
sort_tile_size,
137147
)
138148

139149
n_rows = cnts.shape[0]
140150
vals_gpu = cp.zeros(n_rows, dtype=cp.float32)
151+
stream = cp.cuda.get_current_stream().ptr
152+
153+
if kernel == "auto":
154+
if trim >= _TRIM_SORT_THRESHOLD:
155+
max_row_nnz = int(cp.diff(cnts.indptr).max().get())
156+
kernel = "sorted" if max_row_nnz <= sort_tile_size() else "thread"
157+
else:
158+
kernel = "thread"
159+
160+
if kernel == "sorted":
161+
find_top_k_per_row_sorted(
162+
cnts.data,
163+
cnts.indptr,
164+
n_rows=n_rows,
165+
trim=trim,
166+
vals=vals_gpu,
167+
stream=stream,
168+
)
169+
elif kernel == "thread":
170+
find_top_k_per_row(
171+
cnts.data,
172+
cnts.indptr,
173+
n_rows=n_rows,
174+
trim=trim,
175+
vals=vals_gpu,
176+
stream=stream,
177+
)
178+
else:
179+
raise ValueError(f"Unknown trim kernel: {kernel!r}")
141180

142-
find_top_k_per_row(
143-
cnts.data,
144-
cnts.indptr,
145-
n_rows=cnts.shape[0],
146-
trim=trim,
147-
vals=vals_gpu,
148-
stream=cp.cuda.get_current_stream().ptr,
149-
)
150181
cut_smaller(
151182
cnts.indptr,
152183
cnts.indices,
153184
cnts.data,
154185
vals=vals_gpu,
155-
n_rows=cnts.shape[0],
156-
stream=cp.cuda.get_current_stream().ptr,
186+
n_rows=n_rows,
187+
stream=stream,
157188
)
158189
cnts.eliminate_zeros()
159190
return cnts

tests/test_neighbors.py

Lines changed: 68 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -144,13 +144,78 @@ def test_bbknn():
144144
assert counter / b_stop > 0.9
145145

146146

147-
def test_trimming():
147+
def test_bbknn_distances_sorted_per_row():
148+
# fuzzy_simplicial_set uses the first non-zero distance per row as rho;
149+
# unsorted per-batch columns break sigma estimation and collapse weights.
150+
adata = pbmc68k_reduced()
151+
bbknn(adata, n_pcs=15, batch_key="phase", algorithm="brute")
152+
dists = adata.obsp["distances"]
153+
for start, stop in itertools.pairwise(dists.indptr):
154+
row = dists.data[start:stop]
155+
assert np.all(np.diff(row) >= 0), "bbknn distance rows must be sorted ascending"
156+
157+
158+
def test_bbknn_connectivities_not_collapsed():
159+
# Regression: before the per-row sort fix, mean connectivity on this
160+
# dataset was ~0.85 with most weights pinned near 1.0. With sorted input
161+
# the distribution spreads out properly.
162+
adata = pbmc68k_reduced()
163+
bbknn(adata, n_pcs=15, batch_key="phase", algorithm="brute")
164+
weights = adata.obsp["connectivities"].data
165+
assert weights.mean() < 0.7
166+
assert (weights > 0.99).mean() < 0.5
167+
168+
169+
def test_bbknn_trim_default_matches_upstream():
170+
# bbknn upstream defaults trim = 10 * total_neighbors
171+
# (= 10 * neighbors_within_batch * n_batches).
172+
adata = pbmc68k_reduced()
173+
n_batches = adata.obs["phase"].nunique()
174+
neighbors_within_batch = 3
175+
bbknn(
176+
adata,
177+
n_pcs=15,
178+
batch_key="phase",
179+
algorithm="brute",
180+
neighbors_within_batch=neighbors_within_batch,
181+
)
182+
assert (
183+
adata.uns["neighbors"]["params"]["trim"]
184+
== 10 * neighbors_within_batch * n_batches
185+
)
186+
187+
188+
@pytest.mark.parametrize("trim", [5, 240])
189+
def test_trimming(trim):
190+
# trim=5: typical case.
191+
# trim=240: exercises the kernel's adaptive block-size path. A static
192+
# BLOCK_SIZE=64 would request 60 KB of dynamic shared memory and fail to
193+
# launch (default per-block cap is ~48 KB).
194+
adata = pbmc68k_reduced()
195+
cnts_gpu = X_to_GPU(adata.obsp["connectivities"]).astype(np.float32)
196+
cnts_cpu = adata.obsp["connectivities"].astype(np.float32)
197+
198+
cnts_cpu = trimming_cpu(cnts_cpu, trim)
199+
cnts_gpu = trimming_gpu(cnts_gpu, trim)
200+
201+
cp.testing.assert_array_equal(cnts_cpu.data, cnts_gpu.data)
202+
cp.testing.assert_array_equal(cnts_cpu.indices, cnts_gpu.indices)
203+
cp.testing.assert_array_equal(cnts_cpu.indptr, cnts_gpu.indptr)
204+
205+
206+
@pytest.mark.parametrize("trim", [5, 50, 240])
207+
@pytest.mark.parametrize("kernel", ["thread", "sorted"])
208+
def test_trimming_kernels_agree(trim, kernel):
209+
# Both trim kernels must produce identical results to the CPU reference
210+
# (bbknn.matrix.trimming) on the same input. The "thread" kernel keeps a
211+
# per-thread top-k in shared memory; the "sorted" kernel does one block
212+
# per row with BlockRadixSort.
148213
adata = pbmc68k_reduced()
149214
cnts_gpu = X_to_GPU(adata.obsp["connectivities"]).astype(np.float32)
150215
cnts_cpu = adata.obsp["connectivities"].astype(np.float32)
151216

152-
cnts_cpu = trimming_cpu(cnts_cpu, 5)
153-
cnts_gpu = trimming_gpu(cnts_gpu, 5)
217+
cnts_cpu = trimming_cpu(cnts_cpu, trim)
218+
cnts_gpu = trimming_gpu(cnts_gpu, trim, kernel=kernel)
154219

155220
cp.testing.assert_array_equal(cnts_cpu.data, cnts_gpu.data)
156221
cp.testing.assert_array_equal(cnts_cpu.indices, cnts_gpu.indices)

0 commit comments

Comments
 (0)