Skip to content

Commit c653810

Browse files
selmanozleyenpre-commit-ci[bot]flying-sheep
authored
Fix for spatial_neighbors breaks when transform='spectral' (#1028)
* fix * attempt to add tests and add better documentation to the functions to describe what's happening * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add vibe coded tests prune later * add tests for spectral transform * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * complete the merge conflict issue * add fau and clean up code * make loops parallel * specify fast arrayutils dep * Add fast_array_utils to project dependencies * cache kernel * forgot to save file bf commit * Apply suggestion from @flying-sheep * remove unused imports (idk why this didn't fail on linter) * remove more unused imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move to float32 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add tests for spectral transform * add fau and clean up code * specify fast arrayutils dep * Add fast_array_utils to project dependencies * Apply suggestion from @flying-sheep * remove unused imports (idk why this didn't fail on linter) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * merge conflix fix on deps * format imports again * fix fau import --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Philipp A. <[email protected]>
1 parent f11e20f commit c653810

File tree

3 files changed

+142
-18
lines changed

3 files changed

+142
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dependencies = [
4949
"dask[array]>=2021.2,<=2024.11.2",
5050
"dask-image>=0.5",
5151
"docrep>=0.3.1",
52+
"fast-array-utils",
5253
"fsspec>=2021.11",
5354
"imagecodecs>=2025.8.2,<2026",
5455
"matplotlib>=3.3",

src/squidpy/gr/_build.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
import pandas as pd
1414
from anndata import AnnData
1515
from anndata.utils import make_index_unique
16-
from numba import njit
16+
from fast_array_utils import stats as fau_stats
17+
from numba import njit, prange
1718
from scanpy import logging as logg
1819
from scipy.sparse import (
1920
SparseEfficiencyWarning,
2021
block_diag,
22+
csr_array,
2123
csr_matrix,
2224
isspmatrix_csr,
2325
spmatrix,
@@ -28,9 +30,16 @@
2830
from sklearn.neighbors import NearestNeighbors
2931
from spatialdata import SpatialData
3032
from spatialdata._core.centroids import get_centroids
31-
from spatialdata._core.query.relational_query import get_element_instances, match_element_to_table
33+
from spatialdata._core.query.relational_query import (
34+
get_element_instances,
35+
match_element_to_table,
36+
)
3237
from spatialdata.models import get_table_keys
33-
from spatialdata.models.models import Labels2DModel, Labels3DModel, get_model
38+
from spatialdata.models.models import (
39+
Labels2DModel,
40+
Labels3DModel,
41+
get_model,
42+
)
3443

3544
from squidpy._constants._constants import CoordType, Transform
3645
from squidpy._constants._pkg_constants import Key
@@ -380,7 +389,7 @@ def _build_connectivity(
380389
if delaunay:
381390
tri = Delaunay(coords)
382391
indptr, indices = tri.vertex_neighbor_vertices
383-
Adj = csr_matrix((np.ones_like(indices, dtype=np.float64), indices, indptr), shape=(N, N))
392+
Adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N))
384393

385394
if return_distance:
386395
# fmt: off
@@ -415,7 +424,7 @@ def _build_connectivity(
415424
col_indices = np.concatenate(col_indices)
416425

417426
Adj = csr_matrix(
418-
(np.ones_like(row_indices, dtype=np.float64), (row_indices, col_indices)),
427+
(np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)),
419428
shape=(N, N),
420429
)
421430
if return_distance:
@@ -431,28 +440,67 @@ def _build_connectivity(
431440

432441

433442
@njit
434-
def outer(indices: NDArrayA, indptr: NDArrayA, degrees: NDArrayA) -> NDArrayA:
435-
res = np.empty_like(indices, dtype=np.float64)
436-
start = 0
437-
for i in range(len(indptr) - 1):
438-
ixs = indices[indptr[i] : indptr[i + 1]]
439-
res[start : start + len(ixs)] = degrees[i] * degrees[ixs]
440-
start += len(ixs)
443+
def _csr_bilateral_diag_scale_helper(
444+
mat: csr_array | csr_matrix,
445+
degrees: NDArrayA,
446+
) -> NDArrayA:
447+
"""
448+
Return an array F aligned with CSR non-zeros such that
449+
F[k] = d[i] * data[k] * d[j] for the k-th non-zero (i, j) in CSR order.
450+
451+
Parameters
452+
----------
453+
454+
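mat : csr_array or csr_matrix
455+
Sparse matrix in CSR format. The following attributes are read:
456+
`mat.data` : array of float, the non-zero values;
457+
`mat.indices` : array of int, the column indices;
458+
`mat.indptr` : array of int, the row pointer.
459+
The matrix itself is not modified; results go to a new array.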
460+
degrees : array of float, shape (n,)
461+
Diagonal scaling vector.
462+
463+
Returns
464+
-------
465+
array of float
466+
Length equals len(data). Entry-wise factors d_i * d_j * data[k]
467+
"""
468+
469+
res = np.empty_like(mat.data, dtype=np.float32)
470+
for i in prange(len(mat.indptr) - 1):
471+
ixs = mat.indices[mat.indptr[i] : mat.indptr[i + 1]]
472+
res[mat.indptr[i] : mat.indptr[i + 1]] = degrees[i] * degrees[ixs] * mat.data[mat.indptr[i] : mat.indptr[i + 1]]
441473

442474
return res
443475

444476

477+
def symmetric_normalize_csr(adj: spmatrix) -> csr_matrix:
478+
"""
479+
Return D^{-1/2} * A * D^{-1/2}, where D = diag(degrees(A)) and A = adj.
480+
481+
482+
Parameters
483+
----------
484+
adj : scipy.sparse.csr_matrix
485+
486+
Returns
487+
-------
488+
scipy.sparse.csr_matrix
489+
"""
490+
degrees = np.squeeze(np.array(np.sqrt(1.0 / fau_stats.sum(adj, axis=0))))
491+
if adj.shape[0] != len(degrees):
492+
raise ValueError("len(degrees) must equal number of rows of adj")
493+
res_data = _csr_bilateral_diag_scale_helper(adj, degrees)
494+
return csr_matrix((res_data, adj.indices, adj.indptr), shape=adj.shape)
495+
496+
445497
def _transform_a_spectral(a: spmatrix) -> spmatrix:
446498
if not isspmatrix_csr(a):
447499
a = a.tocsr()
448500
if not a.nnz:
449501
return a
450502

451-
degrees = np.squeeze(np.array(np.sqrt(1.0 / a.sum(axis=0))))
452-
a = a.multiply(outer(a.indices, a.indptr, degrees))
453-
a.eliminate_zeros()
454-
455-
return a
503+
return symmetric_normalize_csr(a)
456504

457505

458506
def _transform_a_cosine(a: spmatrix) -> spmatrix:

tests/graph/test_spatial_neighbors.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,13 @@ def _adata_concat(adata1, adata2):
4949
# TODO: add edge cases
5050
# TODO(giovp): test with reshuffling
5151
@pytest.mark.parametrize(("n_rings", "n_neigh", "sum_dist"), [(1, 6, 0), (2, 18, 30), (3, 36, 84)])
52-
def test_spatial_neighbors_visium(self, visium_adata: AnnData, n_rings: int, n_neigh: int, sum_dist: int):
52+
def test_spatial_neighbors_visium(
53+
self,
54+
visium_adata: AnnData,
55+
n_rings: int,
56+
n_neigh: int,
57+
sum_dist: int,
58+
):
5359
"""
5460
check correctness of neighborhoods for visium coordinates
5561
"""
@@ -355,3 +361,72 @@ def test_mask_graph(
355361
negative_mask=True,
356362
key_added=key_added,
357363
)
364+
365+
def test_spatial_neighbors_transform_mathematical_properties(self, non_visium_adata: AnnData):
366+
"""
367+
Test mathematical properties of each transform.
368+
"""
369+
# Test spectral transform properties
370+
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="spectral")
371+
adj_spectral = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
372+
373+
# Spectral transform should be symmetric
374+
np.testing.assert_allclose(adj_spectral, adj_spectral.T, atol=1e-10)
375+
376+
# Spectral transform should have normalized rows (L2 norm <= 1)
377+
row_norms = np.sqrt(np.sum(adj_spectral**2, axis=1))
378+
np.testing.assert_array_less(row_norms, 1.0 + 1e-10)
379+
380+
# Test cosine transform properties
381+
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="cosine")
382+
adj_cosine = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()
383+
384+
# Cosine transform should be symmetric
385+
np.testing.assert_allclose(adj_cosine, adj_cosine.T, atol=1e-10)
386+
387+
# Cosine transform should have values in [-1, 1]
388+
np.testing.assert_array_less(-1.0 - 1e-10, adj_cosine)
389+
np.testing.assert_array_less(adj_cosine, 1.0 + 1e-10)
390+
391+
# Diagonal of cosine transform should be 1 (self-similarity)
392+
np.testing.assert_allclose(np.diag(adj_cosine), 1.0, atol=1e-10)
393+
394+
def test_spatial_neighbors_transform_edge_cases(self, non_visium_adata: AnnData):
395+
"""
396+
Test that every transform (None, spectral, cosine) runs on a very small dataset (5 points).
397+
"""
398+
# Test with a very small dataset
399+
small_adata = non_visium_adata[:5].copy() # Only 5 points
400+
401+
# Test all transforms with small dataset
402+
for transform in [None, "spectral", "cosine"]:
403+
spatial_neighbors(small_adata, delaunay=True, coord_type=None, transform=transform)
404+
assert Key.obsp.spatial_conn() in small_adata.obsp
405+
assert Key.obsp.spatial_dist() in small_adata.obsp
406+
407+
# Verify transform parameter is saved
408+
assert small_adata.uns[Key.uns.spatial_neighs()]["params"]["transform"] == transform
409+
410+
def test_spatial_neighbors_spectral_transform_properties(self, non_visium_adata: AnnData):
411+
"""
412+
Test that spectral transform preserves the nonzero pattern and keeps eigenvalues between -1 and 1.
413+
"""
414+
# Apply spatial_neighbors without transform
415+
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform=None)
416+
adj_no_transform = non_visium_adata.obsp[Key.obsp.spatial_conn()].copy()
417+
418+
# Apply spatial_neighbors with spectral transform
419+
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="spectral")
420+
adj_spectral = non_visium_adata.obsp[Key.obsp.spatial_conn()]
421+
422+
# Check that nonzero patterns are identical
423+
np.testing.assert_array_equal(
424+
adj_no_transform.nonzero(),
425+
adj_spectral.nonzero(),
426+
err_msg="Spectral transform should preserve the sparsity pattern",
427+
)
428+
429+
w = np.linalg.eigvals(adj_spectral.toarray())
430+
# Eigenvalues should be in range [-1, 1]
431+
np.testing.assert_array_less(w, 1.0, err_msg="Eigenvalues should be <= 1")
432+
np.testing.assert_array_less(-1.0, w, err_msg="Eigenvalues should be >= -1")

0 commit comments

Comments
 (0)