Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5f61dd5
init
selmanozleyen Mar 2, 2026
40dad3e
debloat spatial_neighbours
selmanozleyen Mar 2, 2026
c81969f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2026
7eef650
Merge branch 'main' into chore/sdata-access
selmanozleyen Mar 11, 2026
8a7cc98
use a func instead of decorator
selmanozleyen Mar 11, 2026
365dad7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
2751464
update niche
selmanozleyen Mar 11, 2026
cde62d5
build
selmanozleyen Mar 11, 2026
137f2c7
remove dupcode
selmanozleyen Mar 11, 2026
b358304
fix build
selmanozleyen Mar 11, 2026
da420f8
Merge branch 'main' into chore/sdata-access
selmanozleyen Mar 11, 2026
1608430
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2026
8c7c9e3
add tests
selmanozleyen Mar 11, 2026
dbdc166
update conf
selmanozleyen Mar 16, 2026
6039e68
Merge branch 'main' into chore/sdata-access
selmanozleyen Mar 16, 2026
f53ae10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
53ba958
remove dead code
selmanozleyen Mar 16, 2026
466124d
Merge branch 'main' into chore/sdata-access
selmanozleyen Apr 2, 2026
6e60795
Merge branch 'main' into chore/sdata-access
selmanozleyen Apr 7, 2026
df6c923
set table default on comnpute_nicche
selmanozleyen Apr 7, 2026
d2439f7
table_key should fit the other signatures
selmanozleyen Apr 7, 2026
d185f8f
remove module
selmanozleyen Apr 9, 2026
67f34fc
undo conf
selmanozleyen Apr 9, 2026
0bdbf8d
rename extract+adata
selmanozleyen Apr 9, 2026
d409c76
Merge branch 'main' into chore/sdata-access
selmanozleyen Apr 13, 2026
d99bbdc
Merge branch 'main' into chore/sdata-access
selmanozleyen May 12, 2026
c0d696f
Merge branch 'main' into chore/sdata-access
selmanozleyen May 12, 2026
cc67e44
Merge branch 'main' into chore/sdata-access
selmanozleyen May 15, 2026
0518d51
expose table key public API
selmanozleyen May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/squidpy/_constants/_pkg_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,18 @@ def _sort_haystack(
return haystack

class obsp:
@staticmethod
def _spatial_key(value: str | None, suffix: str) -> str:
if value is None:
return f"{Key.obsm.spatial}_{suffix}"
if value.endswith(f"_{suffix}"):
return value
return f"{value}_{suffix}"

@classmethod
def spatial_dist(cls, value: str | None = None) -> str:
return f"{Key.obsm.spatial}_distances" if value is None else f"{value}_distances"
return cls._spatial_key(value, "distances")

@classmethod
def spatial_conn(cls, value: str | None = None) -> str:
return f"{Key.obsm.spatial}_connectivities" if value is None else f"{value}_connectivities"
return cls._spatial_key(value, "connectivities")
5 changes: 5 additions & 0 deletions src/squidpy/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def decorator2(obj: Any) -> Any:
_adata = """\
adata
Annotated data object."""
_table_key = """\
table_key
Key in :attr:`spatialdata.SpatialData.tables` where the table is stored. Required when ``adata`` is a
:class:`spatialdata.SpatialData` object and ignored otherwise."""
_img_container = """\
img
High-resolution image."""
Expand Down Expand Up @@ -361,6 +365,7 @@ def decorator2(obj: Any) -> Any:

d = DocstringProcessor(
adata=_adata,
table_key=_table_key,
img_container=_img_container,
copy=_copy,
copy_cont=_copy_cont,
Expand Down
114 changes: 65 additions & 49 deletions src/squidpy/gr/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_assert_categorical_obs,
_assert_spatial_basis,
_save_data,
extract_adata_if_sdata,
)

__all__ = ["spatial_neighbors"]
Expand Down Expand Up @@ -140,54 +141,14 @@ def spatial_neighbors(
- :attr:`anndata.AnnData.obsp` ``['{{key_added}}_distances']`` - the spatial distances.
- :attr:`anndata.AnnData.uns` ``['{{key_added}}']`` - :class:`dict` containing parameters.
"""
if isinstance(adata, SpatialData):
assert elements_to_coordinate_systems is not None, (
"Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`."
)
assert table_key is not None, (
"Since `adata` is a :class:`spatialdata.SpatialData`, `table_key` must not be `None`."
)
elements, table = match_element_to_table(adata, list(elements_to_coordinate_systems), table_key)
assert table.obs_names.equals(adata.tables[table_key].obs_names), (
"The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary."
)
regions, region_key, instance_key = get_table_keys(adata.tables[table_key])
regions = [regions] if isinstance(regions, str) else regions
ordered_regions_in_table = adata.tables[table_key].obs[region_key].unique()

# TODO: remove this after https://github.com/scverse/spatialdata/issues/614
remove_centroids = {}
elem_instances = []
for e in regions:
schema = get_model(elements[e])
element_instances = get_element_instances(elements[e]).to_series()
if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)):
element_instances = element_instances.drop(index=0)
remove_centroids[e] = True
else:
remove_centroids[e] = False
elem_instances.append(element_instances)

element_instances = pd.concat(elem_instances)
if (not np.all(element_instances.values == adata.tables[table_key].obs[instance_key].values)) or (
not np.all(ordered_regions_in_table == regions)
):
raise ValueError(
"The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary."
)
centroids = []
for region_ in ordered_regions_in_table:
cs = elements_to_coordinate_systems[region_]
centroid = get_centroids(adata[region_], coordinate_system=cs)[["x", "y"]].compute()

# TODO: remove this after https://github.com/scverse/spatialdata/issues/614
if remove_centroids[region_]:
centroid = centroid[1:].copy()
centroids.append(centroid)

adata.tables[table_key].obsm[spatial_key] = np.concatenate(centroids)
adata = adata.tables[table_key]
library_key = region_key
adata, library_key = _resolve_sdata(
adata=adata,
elements_to_coordinate_systems=elements_to_coordinate_systems,
table_key=table_key,
spatial_key=spatial_key,
library_key=library_key,
)

assert_positive(n_rings, name="n_rings")
assert_positive(n_neighs, name="n_neighs")
Expand Down Expand Up @@ -434,6 +395,62 @@ def _build_connectivity(
return Adj


def _resolve_sdata(
adata: AnnData | SpatialData,
elements_to_coordinate_systems: dict[str, str] | None,
table_key: str | None = None,
spatial_key: str = Key.obsm.spatial,
library_key: str | None = None,
) -> tuple[AnnData, str | None]:
if not isinstance(adata, SpatialData):
return adata, library_key

assert elements_to_coordinate_systems is not None, (
"Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`."
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"adata is sdata" is weird? Maybe it should be resolve_data/input?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes but the error message would still have to be the same since the public signature takes adata, otherwise the error message would be confusing

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is done in many places lets make an issue about it and move on for here. Ideally we would call it data since it's positional argument anyway

)
table = extract_adata_if_sdata(adata, table_key=table_key)
elements, matched_table = match_element_to_table(adata, list(elements_to_coordinate_systems), table_key)
assert matched_table.obs_names.equals(table.obs_names), (
"The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary."
)
regions, region_key, instance_key = get_table_keys(table)
regions = [regions] if isinstance(regions, str) else regions
ordered_regions_in_table = table.obs[region_key].unique()

# TODO: remove this after https://github.com/scverse/spatialdata/issues/614
remove_centroids = {}
elem_instances = []
for e in regions:
schema = get_model(elements[e])
element_instances = get_element_instances(elements[e]).to_series()
if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)):
element_instances = element_instances.drop(index=0)
remove_centroids[e] = True
else:
remove_centroids[e] = False
elem_instances.append(element_instances)

element_instances = pd.concat(elem_instances)
if (not np.all(element_instances.values == table.obs[instance_key].values)) or (
not np.all(ordered_regions_in_table == regions)
):
raise ValueError(
"The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary."
)
centroids = []
for region_ in ordered_regions_in_table:
cs = elements_to_coordinate_systems[region_]
centroid = get_centroids(adata[region_], coordinate_system=cs)[["x", "y"]].compute()

# TODO: remove this after https://github.com/scverse/spatialdata/issues/614
if remove_centroids[region_]:
centroid = centroid[1:].copy()
centroids.append(centroid)

table.obsm[spatial_key] = np.concatenate(centroids)
Comment thread
selmanozleyen marked this conversation as resolved.
return table, region_key


@njit
def _csr_bilateral_diag_scale_helper(
mat: csr_array | csr_matrix,
Expand Down Expand Up @@ -557,8 +574,7 @@ def mask_graph(
if not isinstance(polygon_mask, Polygon | MultiPolygon):
raise ValueError(f"`polygon_mask` should be of type `Polygon` or `MultiPolygon`, got {type(polygon_mask)}")

# get elements
table = sdata.tables[table_key]
table = extract_adata_if_sdata(sdata, table_key=table_key)
coords = table.obsm[spatial_key]
Adj = table.obsp[conns_key]
Dst = table.obsp[dists_key]
Expand Down
7 changes: 5 additions & 2 deletions src/squidpy/gr/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_assert_categorical_obs,
_genesymbols,
_save_data,
extract_adata_if_sdata,
)

__all__ = ["ligrec", "PermutationTest"]
Expand Down Expand Up @@ -641,6 +642,8 @@ def ligrec(
copy: bool = False,
key_added: str | None = None,
gene_symbols: str | None = None,
*,
table_key: str | None = None,
**kwargs: Any,
) -> Mapping[str, pd.DataFrame] | None:
"""
Expand All @@ -649,6 +652,7 @@ def ligrec(
Parameters
----------
%(PT.parameters)s
%(table_key)s
%(PT_prepare_full.parameters)s
%(PT_test.parameters)s
gene_symbols
Expand All @@ -658,8 +662,7 @@ def ligrec(
-------
%(ligrec_test_returns)s
""" # noqa: D400
if isinstance(adata, SpatialData):
adata = adata.table
adata = extract_adata_if_sdata(adata, table_key=table_key)
with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False):
return ( # type: ignore[no-any-return]
PermutationTest(adata, use_raw=use_raw)
Expand Down
19 changes: 13 additions & 6 deletions src/squidpy/gr/_nhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_assert_connectivity_key,
_save_data,
_shuffle_group,
extract_adata_if_sdata,
)

__all__ = ["nhood_enrichment", "centrality_scores", "interaction_matrix"]
Expand Down Expand Up @@ -146,13 +147,16 @@ def nhood_enrichment(
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = True,
*,
table_key: str | None = None,
) -> NhoodEnrichmentResult | None:
"""
Compute neighborhood enrichment by permutation test.

Parameters
----------
%(adata)s
%(table_key)s
%(cluster_key)s
%(library_key)s
%(conn_key)s
Expand All @@ -171,8 +175,7 @@ def nhood_enrichment(
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score.
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count.
"""
if isinstance(adata, SpatialData):
adata = adata.table
adata = extract_adata_if_sdata(adata, table_key=table_key)
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)
Expand Down Expand Up @@ -239,6 +242,8 @@ def centrality_scores(
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = False,
*,
table_key: str | None = None,
) -> pd.DataFrame | None:
"""
Compute centrality scores per cluster or cell type.
Expand All @@ -248,6 +253,7 @@ def centrality_scores(
Parameters
----------
%(adata)s
%(table_key)s
%(cluster_key)s
score
Centrality measures as described in :mod:`networkx.algorithms.centrality` :cite:`networkx`.
Expand All @@ -268,8 +274,7 @@ def centrality_scores(
- :attr:`anndata.AnnData.uns` ``['{{cluster_key}}_centrality_scores']`` - the centrality scores,
as mentioned above.
"""
if isinstance(adata, SpatialData):
adata = adata.table
adata = extract_adata_if_sdata(adata, table_key=table_key)
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)
Expand Down Expand Up @@ -333,13 +338,16 @@ def interaction_matrix(
normalized: bool = False,
copy: bool = False,
weights: bool = False,
*,
table_key: str | None = None,
) -> NDArrayA | None:
"""
Compute interaction matrix for clusters.

Parameters
----------
%(adata)s
%(table_key)s
%(cluster_key)s
%(conn_key)s
normalized
Expand All @@ -356,8 +364,7 @@ def interaction_matrix(

- :attr:`anndata.AnnData.uns` ``['{cluster_key}_interactions']`` - the interaction matrix.
"""
if isinstance(adata, SpatialData):
adata = adata.table
adata = extract_adata_if_sdata(adata, table_key=table_key)
connectivity_key = Key.obsp.spatial_conn(connectivity_key)
_assert_categorical_obs(adata, cluster_key)
_assert_connectivity_key(adata, connectivity_key)
Expand Down
33 changes: 13 additions & 20 deletions src/squidpy/gr/_niche.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from squidpy._constants._constants import NicheDefinitions
from squidpy._docs import d, inject_docs
from squidpy._validators import assert_isinstance, assert_key_in_adata, assert_one_of
from squidpy.gr._utils import extract_adata_if_sdata

__all__ = ["calculate_niche"]

Expand All @@ -31,7 +32,6 @@ def calculate_niche(
data: AnnData | SpatialData,
flavor: Literal["neighborhood", "utag", "cellcharter", "spatialleiden"],
library_key: str | None = None,
table_key: str | None = None,
mask: pd.core.series.Series = None,
groups: str | None = None,
n_neighbors: int | None = None,
Expand All @@ -51,7 +51,9 @@ def calculate_niche(
use_weights: bool | tuple[bool, bool] = True,
use_rep: str | None = None,
inplace: bool = True,
) -> AnnData:
*,
table_key: str | None = None,
) -> AnnData | None:
"""
Calculate niches (spatial clusters) based on a user-defined method in 'flavor'.
The resulting niche labels with be stored in 'adata.obs'.
Expand All @@ -68,8 +70,7 @@ def calculate_niche(
%(library_key)s
If provided, niches will be calculated separately for each unique value in this column.
Each niche will be prefixed with the library identifier.
table_key
Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed.
%(table_key)s
mask
Boolean array to filter cells which won't get assigned to a niche.
Note that if you want to exclude these cells during neighborhood calculation already, you should subset your AnnData table before running 'sq.gr.spatial_neigbors'.
Expand Down Expand Up @@ -169,12 +170,8 @@ def calculate_niche(
if resolutions is None:
resolutions = [0.5]

if isinstance(data, SpatialData):
orig_adata = data.tables[table_key]
adata = orig_adata.copy()
else:
orig_adata = data
adata = data.copy()
orig_adata = extract_adata_if_sdata(data, table_key=table_key)
adata = orig_adata.copy()

assert_key_in_adata(
adata,
Expand Down Expand Up @@ -843,18 +840,14 @@ def _validate_niche_args(

assert_one_of(flavor, ["neighborhood", "utag", "cellcharter", "spatialleiden"], name="flavor")

if isinstance(data, SpatialData) and table_key is None:
raise TypeError("missing required keyword-only argument: 'table_key'")

if library_key is not None:
assert_isinstance(library_key, str, name="library_key")
if isinstance(data, AnnData):
if library_key not in data.obs.columns:
raise ValueError(f"'library_key' must be a column in 'adata.obs', got {library_key}")
elif isinstance(data, SpatialData):
if table_key is None:
raise ValueError("'table_key' is required when 'data' is a SpatialData object")
if table_key not in data.tables:
raise ValueError(f"'table_key' must be a valid table key in 'data', got {table_key}")
if library_key not in data.tables[table_key].obs.columns:
raise ValueError(f"'library_key' must be a column in 'adata.obs', got {library_key}")
adata = extract_adata_if_sdata(data, table_key=table_key)
if library_key not in adata.obs.columns:
raise ValueError(f"'library_key' must be a column in 'adata.obs', got {library_key}")

if n_neighbors is not None:
assert_isinstance(n_neighbors, int, name="n_neighbors")
Expand Down
Loading
Loading