Skip to content

Commit 2926263

Browse files
authored
Fix BBKNN Trimming (#659)
* Inital fix * add release note and up blocksort size * adress comments
1 parent 42123b3 commit 2926263

6 files changed

Lines changed: 232 additions & 19 deletions

File tree

docs/release-notes/0.15.1.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
```{rubric} Bug fixes
99
```
1010
* Fixes `tl.rank_genes_groups` returning NaN/zero `logfoldchanges`/`pvals` with `groups=[subset]` and `reference='rest'` {pr}`651` {smaller}`S Dicks`
11+
* Fixes `pp.bbknn` connectivities diverging from upstream `bbknn`: per-batch neighbours are now sorted by distance before `fuzzy_simplicial_set` (so weights no longer collapse near 1.0), and the default `trim` matches upstream (`10 * neighbors_within_batch * n_batches`). Trimming kernel no longer crashes for large `trim`, and a new block-cooperative sort kernel is auto-dispatched for large `trim` for substantial speedups {pr}`659` {smaller}`S Dicks`
1112
* Fixes float64 precision loss in `pp.normalize_pearson_residuals` on CSR/CSC input {pr}`658` {smaller}`A Mikaeili & S Dicks`
1213

1314
```{rubric} Misc

src/rapids_singlecell/_cuda/bbknn/bbknn.cu

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,54 @@
66
using namespace nb::literals;
77

88
constexpr int BLOCK_SIZE = 64;
9+
// Block-cooperative sort kernel: BLOCK_THREADS * ITEMS_PER_THREAD = 2048.
10+
// Rows larger than this must use the per-thread kernel (kernel 1).
11+
constexpr int SORT_BLOCK_THREADS = 128;
12+
constexpr int SORT_ITEMS_PER_THREAD = 16;
13+
constexpr int SORT_TILE_SIZE = SORT_BLOCK_THREADS * SORT_ITEMS_PER_THREAD;
914

1015
static inline void launch_find_top_k_per_row(const float* data,
1116
const int* indptr, int n_rows,
1217
int trim, float* vals,
1318
cudaStream_t stream) {
14-
dim3 block(BLOCK_SIZE);
15-
dim3 grid((n_rows + BLOCK_SIZE - 1) / BLOCK_SIZE);
16-
size_t shared_mem_size = static_cast<size_t>(BLOCK_SIZE) *
17-
static_cast<size_t>(trim) * sizeof(float);
19+
// Each thread keeps its row's top-`trim` values in shared memory, so the
20+
// per-block shared-mem request is BLOCK_SIZE * trim * sizeof(float).
21+
// The default per-block shared-mem cap is ~48 KB; halve the block size
22+
// until the request fits so the launch succeeds for any reasonable trim.
23+
constexpr size_t SHARED_MEM_BUDGET = 48 * 1024;
24+
const size_t per_thread_bytes = static_cast<size_t>(trim) * sizeof(float);
25+
int block_size = BLOCK_SIZE;
26+
while (block_size > 1 &&
27+
static_cast<size_t>(block_size) * per_thread_bytes >
28+
SHARED_MEM_BUDGET) {
29+
block_size /= 2;
30+
}
31+
if (static_cast<size_t>(block_size) * per_thread_bytes >
32+
SHARED_MEM_BUDGET) {
33+
throw std::runtime_error(
34+
"find_top_k_per_row: trim too large for shared-memory budget; "
35+
"use find_top_k_per_row_sorted instead");
36+
}
37+
dim3 block(block_size);
38+
dim3 grid((n_rows + block_size - 1) / block_size);
39+
size_t shared_mem_size = static_cast<size_t>(block_size) * per_thread_bytes;
1840
find_top_k_per_row_kernel<<<grid, block, shared_mem_size, stream>>>(
1941
data, indptr, n_rows, trim, vals);
2042
CUDA_CHECK_LAST_ERROR(find_top_k_per_row_kernel);
2143
}
2244

45+
static inline void launch_find_top_k_per_row_sorted(const float* data,
46+
const int* indptr,
47+
int n_rows, int trim,
48+
float* vals,
49+
cudaStream_t stream) {
50+
dim3 block(SORT_BLOCK_THREADS);
51+
dim3 grid(n_rows);
52+
find_top_k_per_row_sorted_kernel<SORT_BLOCK_THREADS, SORT_ITEMS_PER_THREAD>
53+
<<<grid, block, 0, stream>>>(data, indptr, n_rows, trim, vals);
54+
CUDA_CHECK_LAST_ERROR(find_top_k_per_row_sorted_kernel);
55+
}
56+
2357
static inline void launch_cut_smaller(int* indptr, int* index, float* data,
2458
float* vals, int n_rows,
2559
cudaStream_t stream) {
@@ -43,6 +77,20 @@ void register_bindings(nb::module_& m) {
4377
"data"_a, "indptr"_a, nb::kw_only(), "n_rows"_a, "trim"_a, "vals"_a,
4478
"stream"_a = 0);
4579

80+
m.def(
81+
"find_top_k_per_row_sorted",
82+
[](gpu_array_c<const float, Device> data,
83+
gpu_array_c<const int, Device> indptr, int n_rows, int trim,
84+
gpu_array_c<float, Device> vals, std::uintptr_t stream) {
85+
launch_find_top_k_per_row_sorted(data.data(), indptr.data(), n_rows,
86+
trim, vals.data(),
87+
(cudaStream_t)stream);
88+
},
89+
"data"_a, "indptr"_a, nb::kw_only(), "n_rows"_a, "trim"_a, "vals"_a,
90+
"stream"_a = 0);
91+
92+
m.def("sort_tile_size", []() { return SORT_TILE_SIZE; });
93+
4694
m.def(
4795
"cut_smaller",
4896
[](gpu_array_c<int, Device> indptr, gpu_array_c<int, Device> index,

src/rapids_singlecell/_cuda/bbknn/kernels_bbknn.cuh

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3+
#include <cub/block/block_radix_sort.cuh>
34
#include <cuda_runtime.h>
5+
#include <math_constants.h>
46

57
__global__ void find_top_k_per_row_kernel(const float* __restrict__ data,
68
const int* __restrict__ indptr,
@@ -49,6 +51,59 @@ __global__ void find_top_k_per_row_kernel(const float* __restrict__ data,
4951
vals[row] = top_k[min_index];
5052
}
5153

54+
// Block-cooperative variant: one CUDA block per row, sorts the row with
55+
// BlockRadixSort, returns the `trim`-th largest as the cut value. Shared
56+
// memory is the CUB sort temp storage only, independent of `trim`, so it
57+
// scales to large `trim` values where the per-thread top-k kernel runs out
58+
// of shared memory. Requires every row to fit in BLOCK_THREADS *
59+
// ITEMS_PER_THREAD.
60+
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
61+
__global__ void find_top_k_per_row_sorted_kernel(const float* __restrict__ data,
62+
const int* __restrict__ indptr,
63+
const int n_rows,
64+
const int trim,
65+
float* __restrict__ vals) {
66+
int row = blockIdx.x;
67+
if (row >= n_rows) {
68+
return;
69+
}
70+
71+
int start = indptr[row];
72+
int end = indptr[row + 1];
73+
int length = end - start;
74+
75+
if (length <= trim) {
76+
if (threadIdx.x == 0) {
77+
vals[row] = 0.0f; // insufficient elements
78+
}
79+
return;
80+
}
81+
82+
using BlockRadixSort =
83+
cub::BlockRadixSort<float, BLOCK_THREADS, ITEMS_PER_THREAD>;
84+
__shared__ typename BlockRadixSort::TempStorage temp_storage;
85+
86+
float thread_keys[ITEMS_PER_THREAD];
87+
#pragma unroll
88+
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
89+
int idx = threadIdx.x * ITEMS_PER_THREAD + i;
90+
// Pad out-of-range with -inf so they sort to the bottom of a
91+
// descending sort and never appear among the trim largest.
92+
thread_keys[i] = (idx < length) ? data[start + idx] : -CUDART_INF_F;
93+
}
94+
95+
BlockRadixSort(temp_storage).SortDescending(thread_keys);
96+
97+
// After SortDescending with blocked arrangement, sorted index i lives at
98+
// thread (i / ITEMS_PER_THREAD), local slot (i % ITEMS_PER_THREAD).
99+
int target_idx = trim - 1;
100+
int target_thread = target_idx / ITEMS_PER_THREAD;
101+
int target_item = target_idx % ITEMS_PER_THREAD;
102+
if (threadIdx.x == target_thread) {
103+
vals[row] = thread_keys[target_item];
104+
}
105+
}
106+
52107
__global__ void cut_smaller_kernel(const int* __restrict__ indptr,
53108
const int* __restrict__ index,
54109
float* __restrict__ data,

src/rapids_singlecell/preprocessing/_neighbors/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,16 @@ def bbknn(
407407
knn_indices[:, col_range] = ind_to[sub_ind]
408408
knn_dist[:, col_range] = sub_dist
409409

410+
# Sort each row so neighbors are ordered closest-first across all batches.
411+
# fuzzy_simplicial_set uses the first non-zero distance per row as the
412+
# local-connectivity rho; unsorted input collapses sigma and weights.
413+
order = cp.argsort(knn_dist, axis=1)
414+
row_idx = cp.arange(n_obs)[:, None]
415+
knn_dist = knn_dist[row_idx, order]
416+
knn_indices = knn_indices[row_idx, order]
417+
410418
if trim is None:
411-
trim = 10 * neighbors_within_batch
419+
trim = 10 * total_neighbors
412420

413421
params = dict(
414422
n_neighbors=total_neighbors,

src/rapids_singlecell/preprocessing/_neighbors/_helper/__init__.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,30 +130,63 @@ def _fix_self_distances(knn_dist: cp.ndarray, metric: _Metrics) -> cp.ndarray:
130130
return knn_dist
131131

132132

133-
def _trimming(cnts: cp_sparse.csr_matrix, trim: int) -> cp_sparse.csr_matrix:
133+
# Empirically, the block-cooperative CUB sort kernel is faster for trim >= 100;
134+
# below this threshold the per-thread top-k kernel has less launch overhead.
135+
_TRIM_SORT_THRESHOLD = 100
136+
137+
138+
def _trimming(
139+
cnts: cp_sparse.csr_matrix,
140+
trim: int,
141+
*,
142+
kernel: str = "auto",
143+
) -> cp_sparse.csr_matrix:
134144
from rapids_singlecell._cuda._bbknn_cuda import (
135145
cut_smaller,
136146
find_top_k_per_row,
147+
find_top_k_per_row_sorted,
148+
sort_tile_size,
137149
)
138150

139151
n_rows = cnts.shape[0]
140152
vals_gpu = cp.zeros(n_rows, dtype=cp.float32)
153+
stream = cp.cuda.get_current_stream().ptr
154+
155+
if kernel == "auto":
156+
if trim >= _TRIM_SORT_THRESHOLD:
157+
max_row_nnz = int(cp.diff(cnts.indptr).max().get())
158+
kernel = "sorted" if max_row_nnz <= sort_tile_size() else "thread"
159+
else:
160+
kernel = "thread"
161+
162+
if kernel == "sorted":
163+
find_top_k_per_row_sorted(
164+
cnts.data,
165+
cnts.indptr,
166+
n_rows=n_rows,
167+
trim=trim,
168+
vals=vals_gpu,
169+
stream=stream,
170+
)
171+
elif kernel == "thread":
172+
find_top_k_per_row(
173+
cnts.data,
174+
cnts.indptr,
175+
n_rows=n_rows,
176+
trim=trim,
177+
vals=vals_gpu,
178+
stream=stream,
179+
)
180+
else:
181+
raise ValueError(f"Unknown trim kernel: {kernel!r}")
141182

142-
find_top_k_per_row(
143-
cnts.data,
144-
cnts.indptr,
145-
n_rows=cnts.shape[0],
146-
trim=trim,
147-
vals=vals_gpu,
148-
stream=cp.cuda.get_current_stream().ptr,
149-
)
150183
cut_smaller(
151184
cnts.indptr,
152185
cnts.indices,
153186
cnts.data,
154187
vals=vals_gpu,
155-
n_rows=cnts.shape[0],
156-
stream=cp.cuda.get_current_stream().ptr,
188+
n_rows=n_rows,
189+
stream=stream,
157190
)
158191
cnts.eliminate_zeros()
159192
return cnts

tests/test_neighbors.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,81 @@ def test_bbknn():
144144
assert counter / b_stop > 0.9
145145

146146

147-
def test_trimming():
147+
def test_bbknn_distances_sorted_per_row():
148+
# fuzzy_simplicial_set uses the first non-zero distance per row as rho;
149+
# unsorted per-batch columns break sigma estimation and collapse weights.
150+
adata = pbmc68k_reduced()
151+
bbknn(adata, n_pcs=15, batch_key="phase", algorithm="brute")
152+
dists = adata.obsp["distances"]
153+
for start, stop in itertools.pairwise(dists.indptr):
154+
row = dists.data[start:stop]
155+
assert np.all(np.diff(row) >= 0), "bbknn distance rows must be sorted ascending"
156+
157+
158+
def test_bbknn_connectivities_not_collapsed():
159+
# Regression: before the per-row sort fix, mean connectivity on this
160+
# dataset was ~0.85 with most weights pinned near 1.0. With sorted input
161+
# the distribution spreads out properly.
162+
adata = pbmc68k_reduced()
163+
bbknn(adata, n_pcs=15, batch_key="phase", algorithm="brute")
164+
weights = adata.obsp["connectivities"].data
165+
# Regression bounds chosen vs pre-fix behaviour (mean ~0.60, >0.99 frac ~0.28
166+
# on this dataset). Anything close to those values indicates sigma estimation
167+
# has broken again. Healthy distribution on the same dataset is mean ~0.50.
168+
assert weights.mean() < 0.7
169+
assert (weights > 0.99).mean() < 0.5
170+
171+
172+
def test_bbknn_trim_default_matches_upstream():
173+
# bbknn upstream defaults trim = 10 * total_neighbors
174+
# (= 10 * neighbors_within_batch * n_batches).
175+
adata = pbmc68k_reduced()
176+
n_batches = adata.obs["phase"].nunique()
177+
neighbors_within_batch = 3
178+
bbknn(
179+
adata,
180+
n_pcs=15,
181+
batch_key="phase",
182+
algorithm="brute",
183+
neighbors_within_batch=neighbors_within_batch,
184+
)
185+
assert (
186+
adata.uns["neighbors"]["params"]["trim"]
187+
== 10 * neighbors_within_batch * n_batches
188+
)
189+
190+
191+
@pytest.mark.parametrize("trim", [5, 240])
192+
def test_trimming(trim):
193+
# trim=5: typical case.
194+
# trim=240: exercises the kernel's adaptive block-size path. A static
195+
# BLOCK_SIZE=64 would request 60 KB of dynamic shared memory and fail to
196+
# launch (default per-block cap is ~48 KB).
197+
adata = pbmc68k_reduced()
198+
cnts_gpu = X_to_GPU(adata.obsp["connectivities"]).astype(np.float32)
199+
cnts_cpu = adata.obsp["connectivities"].astype(np.float32)
200+
201+
cnts_cpu = trimming_cpu(cnts_cpu, trim)
202+
cnts_gpu = trimming_gpu(cnts_gpu, trim)
203+
204+
cp.testing.assert_array_equal(cnts_cpu.data, cnts_gpu.data)
205+
cp.testing.assert_array_equal(cnts_cpu.indices, cnts_gpu.indices)
206+
cp.testing.assert_array_equal(cnts_cpu.indptr, cnts_gpu.indptr)
207+
208+
209+
@pytest.mark.parametrize("trim", [5, 50, 240])
210+
@pytest.mark.parametrize("kernel", ["thread", "sorted"])
211+
def test_trimming_kernels_agree(trim, kernel):
212+
# Both trim kernels must produce identical results to the CPU reference
213+
# (bbknn.matrix.trimming) on the same input. The "thread" kernel keeps a
214+
# per-thread top-k in shared memory; the "sorted" kernel does one block
215+
# per row with BlockRadixSort.
148216
adata = pbmc68k_reduced()
149217
cnts_gpu = X_to_GPU(adata.obsp["connectivities"]).astype(np.float32)
150218
cnts_cpu = adata.obsp["connectivities"].astype(np.float32)
151219

152-
cnts_cpu = trimming_cpu(cnts_cpu, 5)
153-
cnts_gpu = trimming_gpu(cnts_gpu, 5)
220+
cnts_cpu = trimming_cpu(cnts_cpu, trim)
221+
cnts_gpu = trimming_gpu(cnts_gpu, trim, kernel=kernel)
154222

155223
cp.testing.assert_array_equal(cnts_cpu.data, cnts_gpu.data)
156224
cp.testing.assert_array_equal(cnts_cpu.indices, cnts_gpu.indices)

0 commit comments

Comments
 (0)