@@ -144,13 +144,78 @@ def test_bbknn():
144144 assert counter / b_stop > 0.9
145145
146146
147- def test_trimming ():
147+ def test_bbknn_distances_sorted_per_row ():
148+ # fuzzy_simplicial_set uses the first non-zero distance per row as rho;
149+ # unsorted per-batch columns break sigma estimation and collapse weights.
150+ adata = pbmc68k_reduced ()
151+ bbknn (adata , n_pcs = 15 , batch_key = "phase" , algorithm = "brute" )
152+ dists = adata .obsp ["distances" ]
153+ for start , stop in itertools .pairwise (dists .indptr ):
154+ row = dists .data [start :stop ]
155+ assert np .all (np .diff (row ) >= 0 ), "bbknn distance rows must be sorted ascending"
156+
157+
158+ def test_bbknn_connectivities_not_collapsed ():
159+ # Regression: before the per-row sort fix, mean connectivity on this
160+ # dataset was ~0.85 with most weights pinned near 1.0. With sorted input
161+ # the distribution spreads out properly.
162+ adata = pbmc68k_reduced ()
163+ bbknn (adata , n_pcs = 15 , batch_key = "phase" , algorithm = "brute" )
164+ weights = adata .obsp ["connectivities" ].data
165+ assert weights .mean () < 0.7
166+ assert (weights > 0.99 ).mean () < 0.5
167+
168+
169+ def test_bbknn_trim_default_matches_upstream ():
170+ # bbknn upstream defaults trim = 10 * total_neighbors
171+ # (= 10 * neighbors_within_batch * n_batches).
172+ adata = pbmc68k_reduced ()
173+ n_batches = adata .obs ["phase" ].nunique ()
174+ neighbors_within_batch = 3
175+ bbknn (
176+ adata ,
177+ n_pcs = 15 ,
178+ batch_key = "phase" ,
179+ algorithm = "brute" ,
180+ neighbors_within_batch = neighbors_within_batch ,
181+ )
182+ assert (
183+ adata .uns ["neighbors" ]["params" ]["trim" ]
184+ == 10 * neighbors_within_batch * n_batches
185+ )
186+
187+
188+ @pytest .mark .parametrize ("trim" , [5 , 240 ])
189+ def test_trimming (trim ):
190+ # trim=5: typical case.
191+ # trim=240: exercises the kernel's adaptive block-size path. A static
192+ # BLOCK_SIZE=64 would request 60 KB of dynamic shared memory and fail to
193+ # launch (default per-block cap is ~48 KB).
194+ adata = pbmc68k_reduced ()
195+ cnts_gpu = X_to_GPU (adata .obsp ["connectivities" ]).astype (np .float32 )
196+ cnts_cpu = adata .obsp ["connectivities" ].astype (np .float32 )
197+
198+ cnts_cpu = trimming_cpu (cnts_cpu , trim )
199+ cnts_gpu = trimming_gpu (cnts_gpu , trim )
200+
201+ cp .testing .assert_array_equal (cnts_cpu .data , cnts_gpu .data )
202+ cp .testing .assert_array_equal (cnts_cpu .indices , cnts_gpu .indices )
203+ cp .testing .assert_array_equal (cnts_cpu .indptr , cnts_gpu .indptr )
204+
205+
206+ @pytest .mark .parametrize ("trim" , [5 , 50 , 240 ])
207+ @pytest .mark .parametrize ("kernel" , ["thread" , "sorted" ])
208+ def test_trimming_kernels_agree (trim , kernel ):
209+ # Both trim kernels must produce identical results to the CPU reference
210+ # (bbknn.matrix.trimming) on the same input. The "thread" kernel keeps a
211+ # per-thread top-k in shared memory; the "sorted" kernel does one block
212+ # per row with BlockRadixSort.
148213 adata = pbmc68k_reduced ()
149214 cnts_gpu = X_to_GPU (adata .obsp ["connectivities" ]).astype (np .float32 )
150215 cnts_cpu = adata .obsp ["connectivities" ].astype (np .float32 )
151216
152- cnts_cpu = trimming_cpu (cnts_cpu , 5 )
153- cnts_gpu = trimming_gpu (cnts_gpu , 5 )
217+ cnts_cpu = trimming_cpu (cnts_cpu , trim )
218+ cnts_gpu = trimming_gpu (cnts_gpu , trim , kernel = kernel )
154219
155220 cp .testing .assert_array_equal (cnts_cpu .data , cnts_gpu .data )
156221 cp .testing .assert_array_equal (cnts_cpu .indices , cnts_gpu .indices )
0 commit comments