-
Notifications
You must be signed in to change notification settings - Fork 673
Labels
Triage 🩺 — This issue needs to be triaged by a maintainer
Description
Please make sure these conditions are met
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest version of scanpy.
- (optional) I have confirmed this bug exists on the main branch of scanpy.
What happened?
Supporting custom neighbor transformers via the `transformer` argument of `sc.pp.neighbors` is a powerful feature that scanpy recently added. However, the downstream processing of the neighbor graph assumes that every cell has at least `n_neighbors` neighbors, so adaptive transformers — e.g. ones that drop neighbors beyond a distance threshold — fail (see the error output below). This prevents implementing multiple published algorithms, shared-nearest-neighbor (SNN) graphs, etc.
I would be happy to help enable this use case / fix the bug.
Cheers,
Malte
Minimal code sample
# %%
import numpy as np
from numpy.typing import NDArray
from scipy.sparse import csr_matrix
from sklearn.neighbors import KNeighborsTransformer
import scanpy as sc
import matplotlib.pyplot as plt
def threshold_neighbors(
    distances: NDArray,
    indices: NDArray,
    distance_threshold: float,
    shape: tuple[int, int],
) -> csr_matrix:
    """Build a sparse distance graph keeping only neighbors closer than a threshold.

    Parameters
    ----------
    distances
        ``(n_queries, k)`` neighbor distances as returned by ``kneighbors``;
        row ``i`` belongs to query point ``i``.
    indices
        ``(n_queries, k)`` neighbor indices aligned with ``distances``.
    distance_threshold
        Only neighbors with ``distance < distance_threshold`` (strict) are kept.
    shape
        Shape of the returned matrix, i.e. ``(n_queries, n_samples_fit)``.

    Returns
    -------
    csr_matrix
        Row ``i`` holds the distances from query ``i`` to its retained
        neighbors; different rows may hold different numbers of entries.
    """
    distances = np.asarray(distances)
    indices = np.asarray(indices)
    mask = distances < distance_threshold
    # Row i of the kneighbors output always belongs to query i. Do NOT use
    # indices[:, 0] as the row id: that is the *nearest neighbor*, which is a
    # different point when duplicates tie at distance 0, and is an index into
    # the fitted data (not the query set) when transforming new points.
    row_indices = np.repeat(np.arange(distances.shape[0]), mask.sum(axis=1))
    # Boolean masking walks rows in order, so these stay aligned with row_indices.
    col_indices = indices[mask]
    return csr_matrix((distances[mask], (row_indices, col_indices)), shape=shape)
class AdaptiveKNeighborsTransformer(KNeighborsTransformer):
    """k-NN transformer that keeps only neighbors closer than ``distance_threshold``.

    Like :class:`~sklearn.neighbors.KNeighborsTransformer`, it queries the
    ``n_neighbors`` nearest neighbors (one extra in ``"distance"`` mode, to
    account for the query point itself), then drops every neighbor whose
    distance is ``>= distance_threshold``. Rows of the resulting graph can
    therefore hold fewer than ``n_neighbors`` entries.

    All constructor parameters are explicit keyword arguments (no ``**kwargs``):
    sklearn's ``get_params``/``set_params``/``clone`` only see parameters named
    in the ``__init__`` signature, so a ``**kwargs`` passthrough silently loses
    settings such as ``metric="cosine"`` when the estimator is cloned.
    """

    def __init__(
        self,
        *,
        mode="distance",
        n_neighbors=5,
        distance_threshold=2,
        algorithm="auto",
        leaf_size=30,
        metric="minkowski",
        p=2,
        metric_params=None,
        n_jobs=None,
    ):
        super().__init__(
            mode=mode,
            n_neighbors=n_neighbors,
            algorithm=algorithm,
            leaf_size=leaf_size,
            metric=metric,
            p=p,
            metric_params=metric_params,
            n_jobs=n_jobs,
        )
        # Strict upper bound on kept neighbor distances (see transform).
        self.distance_threshold = distance_threshold

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
        """Delegate to the parent ``kneighbors``, but only for internal callers.

        Raises ``RuntimeError`` unless the calling frame has a local ``self``
        that is this instance (e.g. when called from ``transform``).

        NOTE(review): this is a soft guard based on frame introspection; any
        frame with a local named ``self`` bound to this instance passes, and
        ``inspect.currentframe`` may return ``None`` outside CPython.
        """
        import inspect

        frame = inspect.currentframe()
        # Prevent directly calling the method
        try:
            caller_locals = frame.f_back.f_locals
            if caller_locals.get("self", None) is not self:
                raise RuntimeError("This method should not be called directly. Use the `transform` method instead.")
        finally:
            # Break the frame reference cycle promptly.
            del frame
        return super().kneighbors(X, n_neighbors, return_distance)

    def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
        """Not supported: the thresholded graph is only available via ``transform``."""
        raise NotImplementedError("This method is not implemented. Use the `transform` method instead.")

    def transform(self, X):
        """Return the thresholded sparse distance graph of ``X`` to the fitted data.

        The result has shape ``(n_queries, n_samples_fit_)``; when ``X`` is the
        data the estimator was fitted on, this is square.

        NOTE(review): unlike the parent class, ``mode="connectivity"`` still
        stores distances rather than ones — confirm whether that is intended.
        """
        # Mirrors KNeighborsTransformer: in distance mode the query point itself
        # is returned as a neighbor, so ask for one more.
        add_one = self.mode == "distance"
        nn_dists, nn_indices = self.kneighbors(X, n_neighbors=self.n_neighbors + add_one)
        return threshold_neighbors(
            nn_dists,
            nn_indices,
            self.distance_threshold,
            # Columns index the fitted samples, not the queries.
            (nn_dists.shape[0], self.n_samples_fit_),
        )
# %%
adata = sc.datasets.pbmc3k_processed()
# %%
knn = AdaptiveKNeighborsTransformer(n_neighbors=20, metric="cosine", distance_threshold=0.5)
neighbors = knn.fit_transform(adata.obsm["X_pca"])
# %%
# Number of strictly positive distances per cell (zero-distance entries such
# as the cell itself are not counted). Computed directly on the sparse matrix
# instead of densifying every row, which touched n_cells * n_cells values.
neighbor_counts = np.asarray((neighbors > 0).sum(axis=1)).ravel().tolist()
print(neighbor_counts)
# %%
# One unit-width bin per count; the +2 gives the maximum count its own bin
# instead of merging it into the previous one.
plt.hist(neighbor_counts, bins=np.arange(0, max(neighbor_counts) + 2, 1))
# %%
# Reproduces the reported bug: rows trimmed by the distance threshold hold
# fewer entries than scanpy expects, and this call raises
# "ValueError: could not broadcast input array ..." (see the error output).
sc.pp.neighbors(
    adata,
    n_neighbors=15,
    use_rep="X_pca",
    key_added="connectivities_scanpy",
    transformer=AdaptiveKNeighborsTransformer(n_neighbors=20, metric="cosine", distance_threshold=0.5),
)
Error output
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[8], line 1
----> 1 sc.pp.neighbors(
2 adata,
3 n_neighbors=15,
4 use_rep="X_pca",
5 key_added="connectivities_scanpy",
6 transformer=AdaptiveKNeighborsTransformer(n_neighbors=20, metric="cosine", distance_threshold=0.5),
7 )
File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/__init__.py:194, in neighbors(adata, n_neighbors, n_pcs, use_rep, knn, method, transformer, metric, metric_kwds, random_state, key_added, copy)
192 adata._init_as_actual(adata.copy())
193 neighbors = Neighbors(adata)
--> 194 neighbors.compute_neighbors(
195 n_neighbors,
196 n_pcs=n_pcs,
197 use_rep=use_rep,
198 knn=knn,
199 method=method,
200 transformer=transformer,
201 metric=metric,
202 metric_kwds=metric_kwds,
203 random_state=random_state,
204 )
206 if key_added is None:
207 key_added = "neighbors"
File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/__init__.py:588, in Neighbors.compute_neighbors(self, n_neighbors, n_pcs, use_rep, knn, method, transformer, metric, metric_kwds, random_state)
586 X = _choose_representation(self._adata, use_rep=use_rep, n_pcs=n_pcs)
587 self._distances = transformer.fit_transform(X)
--> 588 knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
589 self._distances, n_neighbors
590 )
591 if shortcut:
592 # self._distances is a sparse matrix with a diag of 1, fix that
593 self._distances[np.diag_indices_from(self.distances)] = 0
File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/_common.py:86, in _get_indices_distances_from_sparse_matrix(D, n_neighbors)
84 indices, distances = shortcut
85 else:
---> 86 indices, distances = _ind_dist_slow(D, n_neighbors)
88 # handle RAPIDS style indices_distances lacking the self-column
89 if not _has_self_column(indices, distances):
File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/_common.py:121, in _ind_dist_slow(D, n_neighbors)
117 distances[i, 1:] = D[i][
118 neighbors[0][sorted_indices], neighbors[1][sorted_indices]
119 ]
120 else:
--> 121 indices[i, 1:] = neighbors[1]
122 distances[i, 1:] = D[i][neighbors]
123 return indices, distances
ValueError: could not broadcast input array from shape (7,) into shape (19,)
Versions
| Package | Version |
| ------------ | ------- |
| anndata | 0.12.2 |
| numpy | 2.2.6 |
| scipy | 1.14.1 |
| scikit-learn | 1.7.2 |
| scanpy | 1.11.4 |
| matplotlib | 3.10.6 |
| Dependency | Version |
| ------------------ | ----------- |
| pure_eval | 0.2.3 |
| matplotlib-inline | 0.1.7 |
| igraph | 0.11.9 |
| psutil | 7.0.0 |
| python-dateutil | 2.9.0.post0 |
| tornado | 6.5.2 |
| ipython | 9.5.0 |
| cycler | 0.12.1 |
| PyYAML | 6.0.2 |
| pyzmq | 27.1.0 |
| crc32c | 2.7.1 |
| platformdirs | 4.4.0 |
| natsort | 8.4.0 |
| asttokens | 3.0.0 |
| numcodecs | 0.16.2 |
| packaging | 25.0 |
| stack-data | 0.6.3 |
| numba | 0.61.2 |
| pytz | 2025.2 |
| xarray | 2025.9.0 |
| joblib | 1.5.2 |
| h5py | 3.14.0 |
| traitlets | 5.14.3 |
| Pygments | 2.19.2 |
| pyparsing | 3.2.4 |
| executing | 2.2.1 |
| pandas | 2.3.2 |
| appnope | 0.1.4 |
| texttable | 1.7.0 |
| jedi | 0.19.2 |
| pillow | 11.3.0 |
| typing_extensions | 4.15.0 |
| wcwidth | 0.2.13 |
| session-info2 | 0.2.2 |
| debugpy | 1.8.16 |
| parso | 0.8.5 |
| threadpoolctl | 3.6.0 |
| legacy-api-wrap | 1.4.1 |
| zarr | 3.1.2 |
| llvmlite | 0.44.0 |
| charset-normalizer | 3.4.3 |
| donfig | 0.8.1.post1 |
| prompt_toolkit | 3.0.52 |
| jupyter_core | 5.8.1 |
| six | 1.17.0 |
| jupyter_client | 8.6.3 |
| kiwisolver | 1.4.9 |
| leidenalg | 0.10.2 |
| ipykernel | 6.30.1 |
| comm | 0.2.3 |
| decorator | 5.2.1 |
| Component | Info |
| --------- | ----------------------------------------------------------------------- |
| Python | 3.13.7 (main, Aug 14 2025, 11:12:11) [Clang 17.0.0 (clang-1700.0.13.3)] |
| OS | macOS-15.6.1-arm64-arm-64bit-Mach-O |
| CPU | 12 logical CPU cores, arm |
| GPU | No GPU found |
| Updated | 2025-09-19 13:28 |
Metadata
Metadata
Assignees
Labels
Triage 🩺 — This issue needs to be triaged by a maintainer