Skip to content

Custom neighbor transformers with variable neighbor count <= k fail #3806

@maltekuehl

Description

@maltekuehl

Please make sure these conditions are met

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest version of scanpy.
  • (optional) I have confirmed this bug exists on the main branch of scanpy.

What happened?

Supporting custom neighbor transformers is a powerful feature that scanpy recently enabled in sc.pp.neighbors. However, the downstream processing of neighbors assumes that each cell has at least k neighbors. Adaptive transformers that return fewer than k neighbors for some cells (e.g. by applying a distance threshold) therefore fail with the ValueError shown below. This prevents the implementation of several published algorithms, of shared-nearest-neighbor (SNN) graphs, and of similar methods.

I would be happy to help enable this use case / fix the bug.

Cheers,
Malte

Minimal code sample

# %%
import numpy as np
from numpy.typing import NDArray
from scipy.sparse import csr_matrix
from sklearn.neighbors import KNeighborsTransformer
import scanpy as sc
import matplotlib.pyplot as plt


def threshold_neighbors(
    distances: NDArray,
    indices: NDArray,
    distance_threshold: float,
    shape: tuple[int, int],
) -> csr_matrix:
    """Build a sparse neighbor-distance graph keeping only neighbors below a threshold.

    Parameters
    ----------
    distances
        2-D array of candidate-neighbor distances; row ``i`` belongs to query ``i``.
    indices
        2-D array of candidate-neighbor column indices, aligned element-wise
        with ``distances``.
    distance_threshold
        Candidates with ``distance >= distance_threshold`` are dropped.
    shape
        Shape of the returned matrix, ``(n_queries, n_columns)``.

    Returns
    -------
    CSR matrix whose row ``i`` holds the retained distances of query ``i``.
    Rows may hold anywhere from 0 up to ``distances.shape[1]`` entries.
    """
    mask = distances < distance_threshold

    # Row i of ``distances``/``indices`` belongs to query i. ``indices[:, 0]``
    # is NOT a reliable row id: it equals the query only when the query is part
    # of the fitted data and wins every distance tie (duplicates break this).
    row_indices = np.repeat(np.arange(distances.shape[0]), mask.sum(axis=1))
    col_indices = indices[mask]

    return csr_matrix((distances[mask], (row_indices, col_indices)), shape=shape)


class AdaptiveKNeighborsTransformer(KNeighborsTransformer):
    """k-NN transformer that additionally drops neighbors beyond a distance cutoff.

    Candidate neighbors are obtained from ``KNeighborsTransformer.kneighbors``.
    Only candidates with a distance strictly below ``distance_threshold`` are
    kept, so each row of the output graph may contain a variable number of
    entries (possibly fewer than ``n_neighbors``).

    Parameters
    ----------
    n_neighbors
        Number of candidate neighbors queried per sample (one more is queried
        when ``mode == "distance"``).
    distance_threshold
        Candidates with ``distance >= distance_threshold`` are discarded.
    **kwargs
        Forwarded to ``KNeighborsTransformer`` (e.g. ``metric``).
    """

    def __init__(
        self,
        *,
        n_neighbors=5,
        distance_threshold=2,
        **kwargs,
    ):
        # NOTE(review): sklearn expects every estimator parameter to appear
        # explicitly in the __init__ signature; values passed only through
        # **kwargs (e.g. metric) may not survive get_params()/clone() — confirm
        # before relying on cloning this transformer.
        kwargs["n_neighbors"] = n_neighbors
        super().__init__(**kwargs)
        self.distance_threshold = distance_threshold
        self.n_neighbors = n_neighbors

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
        """Delegate to the parent ``kneighbors``, but only for calls from this instance.

        Raises
        ------
        RuntimeError
            If the calling frame's ``self`` is not this instance, i.e. the method
            was called directly instead of through ``transform``.
        """
        import inspect

        frame = inspect.currentframe()
        # currentframe() may return None on interpreters without stack-frame
        # support; the guard is then skipped instead of crashing transform().
        caller = frame.f_back if frame is not None else None
        try:
            if caller is not None and caller.f_locals.get("self", None) is not self:
                raise RuntimeError("This method should not be called directly. Use the `transform` method instead.")
        finally:
            # Drop frame references promptly to avoid reference cycles.
            del frame, caller

        return super().kneighbors(X, n_neighbors, return_distance)

    def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
        """Disabled; use ``transform`` to obtain the thresholded graph."""
        raise NotImplementedError("This method is not implemented. Use the `transform` method instead.")

    def transform(self, X):
        """Return the sparse graph of neighbor distances below ``distance_threshold``.

        The result has one row per query sample in ``X`` and one column per
        fitted sample. Stored values are distances regardless of ``mode``.
        """
        # Mirror sklearn's KNeighborsTransformer, which queries one extra
        # neighbor in "distance" mode because it counts each sample as its own
        # neighbor.
        add_one = self.mode == "distance"
        nn_dists, nn_indices = self.kneighbors(X, n_neighbors=self.n_neighbors + add_one)

        # Columns index the *fitted* samples, whose count differs from the
        # number of queries when transforming data other than the training set.
        return threshold_neighbors(
            nn_dists,
            nn_indices,
            self.distance_threshold,
            (nn_dists.shape[0], self.n_samples_fit_),
        )


# %%
adata = sc.datasets.pbmc3k_processed()

# %%
knn = AdaptiveKNeighborsTransformer(n_neighbors=20, metric="cosine", distance_threshold=0.5)
neighbors = knn.fit_transform(adata.obsm["X_pca"])

# %%
# Per-cell count of stored neighbors with a strictly positive distance, which
# excludes entries stored at distance 0 (e.g. a cell listed as its own
# neighbor). Computed on the sparse matrix directly instead of densifying
# one row at a time.
neighbor_counts = np.asarray((neighbors > 0).sum(axis=1)).ravel().tolist()
print(neighbor_counts)

# %%
# One unit-width bin per integer count. The extra edge matters because numpy's
# last histogram bin is closed; without it, the two largest counts would share
# a bin.
plt.hist(neighbor_counts, bins=np.arange(0, max(neighbor_counts) + 2, 1))

# %%
# Reproduces the failure: scanpy assumes every cell has at least k neighbors.
sc.pp.neighbors(
    adata,
    n_neighbors=15,
    use_rep="X_pca",
    key_added="connectivities_scanpy",
    transformer=AdaptiveKNeighborsTransformer(n_neighbors=20, metric="cosine", distance_threshold=0.5),
)

Error output

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 1
----> 1 sc.pp.neighbors(
      2     adata,
      3     n_neighbors=15,
      4     use_rep="X_pca",
      5     key_added="connectivities_scanpy",
      6     transformer=AdaptiveKNeighborsTransformer(n_neighbors=20, metric="cosine", distance_threshold=0.5),
      7 )

File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/__init__.py:194, in neighbors(adata, n_neighbors, n_pcs, use_rep, knn, method, transformer, metric, metric_kwds, random_state, key_added, copy)
    192     adata._init_as_actual(adata.copy())
    193 neighbors = Neighbors(adata)
--> 194 neighbors.compute_neighbors(
    195     n_neighbors,
    196     n_pcs=n_pcs,
    197     use_rep=use_rep,
    198     knn=knn,
    199     method=method,
    200     transformer=transformer,
    201     metric=metric,
    202     metric_kwds=metric_kwds,
    203     random_state=random_state,
    204 )
    206 if key_added is None:
    207     key_added = "neighbors"

File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/__init__.py:588, in Neighbors.compute_neighbors(self, n_neighbors, n_pcs, use_rep, knn, method, transformer, metric, metric_kwds, random_state)
    586 X = _choose_representation(self._adata, use_rep=use_rep, n_pcs=n_pcs)
    587 self._distances = transformer.fit_transform(X)
--> 588 knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
    589     self._distances, n_neighbors
    590 )
    591 if shortcut:
    592     # self._distances is a sparse matrix with a diag of 1, fix that
    593     self._distances[np.diag_indices_from(self.distances)] = 0

File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/_common.py:86, in _get_indices_distances_from_sparse_matrix(D, n_neighbors)
     84     indices, distances = shortcut
     85 else:
---> 86     indices, distances = _ind_dist_slow(D, n_neighbors)
     88 # handle RAPIDS style indices_distances lacking the self-column
     89 if not _has_self_column(indices, distances):

File ~/Documents/code/clusterkit-dev/.venv/lib/python3.13/site-packages/scanpy/neighbors/_common.py:121, in _ind_dist_slow(D, n_neighbors)
    117         distances[i, 1:] = D[i][
    118             neighbors[0][sorted_indices], neighbors[1][sorted_indices]
    119         ]
    120     else:
--> 121         indices[i, 1:] = neighbors[1]
    122         distances[i, 1:] = D[i][neighbors]
    123 return indices, distances

ValueError: could not broadcast input array from shape (7,) into shape (19,)

Versions

| Package      | Version |
| ------------ | ------- |
| anndata      | 0.12.2  |
| numpy        | 2.2.6   |
| scipy        | 1.14.1  |
| scikit-learn | 1.7.2   |
| scanpy       | 1.11.4  |
| matplotlib   | 3.10.6  |

| Dependency         | Version     |
| ------------------ | ----------- |
| pure_eval          | 0.2.3       |
| matplotlib-inline  | 0.1.7       |
| igraph             | 0.11.9      |
| psutil             | 7.0.0       |
| python-dateutil    | 2.9.0.post0 |
| tornado            | 6.5.2       |
| ipython            | 9.5.0       |
| cycler             | 0.12.1      |
| PyYAML             | 6.0.2       |
| pyzmq              | 27.1.0      |
| crc32c             | 2.7.1       |
| platformdirs       | 4.4.0       |
| natsort            | 8.4.0       |
| asttokens          | 3.0.0       |
| numcodecs          | 0.16.2      |
| packaging          | 25.0        |
| stack-data         | 0.6.3       |
| numba              | 0.61.2      |
| pytz               | 2025.2      |
| xarray             | 2025.9.0    |
| joblib             | 1.5.2       |
| h5py               | 3.14.0      |
| traitlets          | 5.14.3      |
| Pygments           | 2.19.2      |
| pyparsing          | 3.2.4       |
| executing          | 2.2.1       |
| pandas             | 2.3.2       |
| appnope            | 0.1.4       |
| texttable          | 1.7.0       |
| jedi               | 0.19.2      |
| pillow             | 11.3.0      |
| typing_extensions  | 4.15.0      |
| wcwidth            | 0.2.13      |
| session-info2      | 0.2.2       |
| debugpy            | 1.8.16      |
| parso              | 0.8.5       |
| threadpoolctl      | 3.6.0       |
| legacy-api-wrap    | 1.4.1       |
| zarr               | 3.1.2       |
| llvmlite           | 0.44.0      |
| charset-normalizer | 3.4.3       |
| donfig             | 0.8.1.post1 |
| prompt_toolkit     | 3.0.52      |
| jupyter_core       | 5.8.1       |
| six                | 1.17.0      |
| jupyter_client     | 8.6.3       |
| kiwisolver         | 1.4.9       |
| leidenalg          | 0.10.2      |
| ipykernel          | 6.30.1      |
| comm               | 0.2.3       |
| decorator          | 5.2.1       |

| Component | Info                                                                    |
| --------- | ----------------------------------------------------------------------- |
| Python    | 3.13.7 (main, Aug 14 2025, 11:12:11) [Clang 17.0.0 (clang-1700.0.13.3)] |
| OS        | macOS-15.6.1-arm64-arm-64bit-Mach-O                                     |
| CPU       | 12 logical CPU cores, arm                                               |
| GPU       | No GPU found                                                            |
| Updated   | 2025-09-19 13:28                                                        |

Metadata

Metadata

Assignees

No one assigned

    Labels

    Triage 🩺 — This issue needs to be triaged by a maintainer

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions