Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Operations on `SpatialData` objects.
.. autofunction:: join_spatialelement_table
.. autofunction:: match_element_to_table
.. autofunction:: match_table_to_element
.. autofunction:: match_sdata_to_table
.. autofunction:: concatenate
.. autofunction:: transform
.. autofunction:: rasterize
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"join_spatialelement_table",
"match_element_to_table",
"match_table_to_element",
"match_sdata_to_table",
"SpatialData",
"get_extent",
"get_centroids",
Expand Down Expand Up @@ -72,6 +73,7 @@
get_values,
join_spatialelement_table,
match_element_to_table,
match_sdata_to_table,
match_table_to_element,
)
from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query
Expand Down
39 changes: 39 additions & 0 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,45 @@ def match_element_to_table(
return element_dict, table


def match_sdata_to_table(
    sdata: SpatialData,
    table_name: str,
    table: AnnData | None = None,
    how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> SpatialData:
    """
    Filter the elements of a SpatialData object so that they match the rows of a table.

    The regions annotated by the table are joined with the table via
    :func:`spatialdata.join_spatialelement_table`; the joined table is then re-parsed so that its metadata is
    overwritten with the annotated regions, and a new object is assembled from the results.

    Parameters
    ----------
    sdata
        SpatialData object containing all the elements and tables.
    table_name
        The name of the table to join with the SpatialData object if `table` is not provided. If `table` is provided,
        `table_name` is used to name the table in the returned `SpatialData` object.
    table
        The table to join with the spatial elements. Has precedence over `table_name`.
    how
        The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".

    Returns
    -------
    A new `SpatialData` object built from the elements returned by the join together with the joined table, which is
    stored under `table_name`.
    """
    # an explicitly passed table wins over the one stored in `sdata`
    source_table = sdata[table_name] if table is None else table

    _, region_key, instance_key = get_table_keys(source_table)
    regions = SpatialData.get_annotated_regions(source_table)

    matched_elements, matched_table = join_spatialelement_table(
        sdata, spatial_element_names=regions, table=source_table, how=how
    )

    # the table metadata is overwritten so that `region` reflects the regions computed above
    parsed_table = TableModel.parse(
        matched_table,
        region=regions,
        region_key=region_key,
        instance_key=instance_key,
        overwrite_metadata=True,
    )

    result_elements = matched_elements | {table_name: parsed_table}
    return SpatialData.init_from_elements(result_elements)


@dataclass
class _ValueOrigin:
origin: str
Expand Down
7 changes: 4 additions & 3 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def from_elements_dict(
return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs)

@staticmethod
def get_annotated_regions(table: AnnData) -> str | list[str]:
def get_annotated_regions(table: AnnData) -> list[str]:
"""
Get the regions annotated by a table.

Expand All @@ -275,8 +275,9 @@ def get_annotated_regions(table: AnnData) -> str | list[str]:
-------
The annotated regions.
"""
regions, _, _ = get_table_keys(table)
return regions
from spatialdata.models.models import _get_region_metadata_from_region_key_column

return _get_region_metadata_from_region_key_column(table)

@staticmethod
def get_region_key_column(table: AnnData) -> pd.Series:
Expand Down
51 changes: 39 additions & 12 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ def parse(
region: str | list[str] | None = None,
region_key: str | None = None,
instance_key: str | None = None,
overwrite_metadata: bool = False,
) -> AnnData:
"""
Parse the :class:`anndata.AnnData` to be compatible with the model.
Expand All @@ -1078,6 +1079,8 @@ def parse(
Key in `adata.obs` that specifies the region.
instance_key
Key in `adata.obs` that specifies the instance.
overwrite_metadata
If `True`, the `region`, `region_key` and `instance_key` metadata will be overwritten.

Returns
-------
Expand All @@ -1087,31 +1090,38 @@ def parse(
# either all live in adata.uns or all be passed in as argument
n_args = sum([region is not None, region_key is not None, instance_key is not None])
if n_args == 0:
return adata
if n_args > 0:
if cls.ATTRS_KEY in adata.uns:
raise ValueError(
f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as"
f"as argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set."
)
elif cls.ATTRS_KEY in adata.uns:
if cls.ATTRS_KEY not in adata.uns:
# table not annotating any element
return adata
attr = adata.uns[cls.ATTRS_KEY]
region = attr[cls.REGION_KEY]
region_key = attr[cls.REGION_KEY_KEY]
instance_key = attr[cls.INSTANCE_KEY]
elif n_args > 0 and not overwrite_metadata and cls.ATTRS_KEY in adata.uns:
raise ValueError(
f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as"
f" argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set."
)

if cls.ATTRS_KEY not in adata.uns:
adata.uns[cls.ATTRS_KEY] = {}

if region is None:
raise ValueError(f"`{cls.REGION_KEY}` must be provided.")
if region_key is None:
raise ValueError(f"`{cls.REGION_KEY_KEY}` must be provided.")
if instance_key is None:
raise ValueError("`instance_key` must be provided.")

if isinstance(region, np.ndarray):
region = region.tolist()
if region is None:
raise ValueError(f"`{cls.REGION_KEY}` must be provided.")
region_: list[str] = region if isinstance(region, list) else [region]
if not adata.obs[region_key].isin(region_).all():
raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.")

if instance_key is None:
raise ValueError("`instance_key` must be provided.")
adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region
adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key
adata.uns[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key

# note! this is an expensive check and therefore we skip it during validation
# https://github.com/scverse/spatialdata/issues/715
Expand Down Expand Up @@ -1214,3 +1224,20 @@ def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]:
raise ValueError(
"No spatialdata_attrs key found in table.uns, therefore, no table keys found. Please parse the table."
)


def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]:
    """
    Compute the annotated regions from the values found in the region key column of a table.

    The name of the region key column is read from the table metadata via `get_table_keys`; the regions are then
    derived from the content of `table.obs[region_key]` rather than from the `region` value stored in the metadata.

    Parameters
    ----------
    table
        The table whose region key column is inspected. It must carry the metadata read by `get_table_keys`
        (NOTE(review): `get_table_keys` appears to raise a `ValueError` when that metadata is missing - confirm).

    Returns
    -------
    The list of distinct values of the region key column. For a categorical column, these are the categories that
    are actually used; otherwise, the unique values of the column.

    Warns
    -----
    UserWarning
        If the region key column is not categorical.
    """
    # only the region key is needed here; the region and instance key metadata are ignored
    _, region_key, _ = get_table_keys(table)
    region_key_column = table.obs[region_key]

    if isinstance(region_key_column.dtype, CategoricalDtype):
        # categories are unique by construction, so dropping the unused ones is sufficient
        annotated_regions = region_key_column.cat.remove_unused_categories().cat.categories.tolist()
    else:
        warnings.warn(
            f"The region key column `{region_key}` is not of type `pd.Categorical`. Consider casting it to "
            f"improve performance.",
            UserWarning,
            stacklevel=2,
        )
        annotated_regions = region_key_column.unique().tolist()

    # kept for static type narrowing: `tolist()` is not typed as returning `list[str]`
    assert isinstance(annotated_regions, list)
    return annotated_regions
144 changes: 144 additions & 0 deletions tests/core/query/test_relational_query_match_sdata_to_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import pytest

from spatialdata import SpatialData, concatenate, match_sdata_to_table
from spatialdata.datasets import blobs_annotating_element


def _make_test_data() -> SpatialData:
    """
    Build the example object used by the tests of this module.

    Two datasets created with `blobs_annotating_element("blobs_polygons")` are concatenated under the keys
    `sdata1` and `sdata2` (tables are concatenated as well), and a `value` column holding `0..n_obs-1` is added to
    the resulting table.
    """
    datasets = {f"sdata{i}": blobs_annotating_element("blobs_polygons") for i in (1, 2)}
    combined = concatenate(datasets, concatenate_tables=True)
    n_obs = combined["table"].obs.shape[0]
    combined["table"].obs["value"] = list(range(n_obs))
    return combined


# Shared example data: built once at import time and reused by most tests of this module, so that it does not have
# to be recreated for every test. Tests that modify the table build their own object with `_make_test_data()`.
sdata = _make_test_data()


def test_match_sdata_to_table_filter_specific_instances():
    """
    Keep only the rows with specific instance ids; this works even though the table annotates multiple elements.
    """
    table = sdata["table"]
    keep = table.obs.instance_id.isin([1, 2])

    result = match_sdata_to_table(sdata, table=table[keep], table_name="table")

    # instances 1 and 2 are kept for each of the two annotated elements
    assert len(result["table"]) == 4
    for element_name in ("blobs_polygons-sdata1", "blobs_polygons-sdata2"):
        assert element_name in result


def test_match_sdata_to_table_filter_specific_instances_element():
    """
    Keep only the rows with specific instance ids that also belong to one specific element.
    """
    table = sdata["table"]
    obs = table.obs
    keep = obs.instance_id.isin([1, 2]) & (obs.region == "blobs_polygons-sdata1")

    result = match_sdata_to_table(sdata, table=table[keep], table_name="table")

    assert len(result["table"]) == 2
    assert "blobs_polygons-sdata1" in result
    # the element that is no longer annotated by the filtered table is dropped
    assert "blobs_polygons-sdata2" not in result


def test_match_sdata_to_table_filter_by_threshold():
    """
    Keep only the rows of one specific element whose `value` column is below a threshold.
    """
    table = sdata["table"]
    selected_index = table.obs.query('value < 5 and region == "blobs_polygons-sdata1"').index

    result = match_sdata_to_table(sdata, table=table[selected_index], table_name="table")

    assert len(result["table"]) == 5
    assert "blobs_polygons-sdata1" in result
    assert "blobs_polygons-sdata2" not in result


def test_match_sdata_to_table_subset_certain_obs():
    """
    Subset the table to certain obs by position (subsetting to certain var or layer would also be possible).
    """
    first_rows = [0, 1, 2, 3]
    subset = sdata["table"][first_rows]

    result = match_sdata_to_table(sdata, table=subset, table_name="table")

    assert len(result["table"]) == 4
    assert "blobs_polygons-sdata1" in result
    assert "blobs_polygons-sdata2" not in result


def test_match_sdata_to_table_shapes_and_points():
    """
    The function works for points as well as for shapes (which are covered by the other tests of this module).

    Here the table is retargeted so that it annotates the points elements instead of the polygons.
    """

    def _polygons_to_points(name):
        return name.replace("polygons", "points")

    # a fresh object is used because the table is modified below
    sdata = _make_test_data()
    table = sdata["table"]
    table.obs["region"] = table.obs["region"].apply(_polygons_to_points)
    table.obs["region"] = table.obs["region"].astype("category")
    sdata.set_table_annotates_spatialelement(
        table_name="table",
        region=["blobs_points-sdata1", "blobs_points-sdata2"],
        region_key="region",
        instance_key="instance_id",
    )

    result = match_sdata_to_table(sdata, table=sdata["table"], table_name="table")

    assert len(result["table"]) == 10
    for element_name in ("blobs_points-sdata1", "blobs_points-sdata2"):
        assert element_name in result
    assert "blobs_polygons-sdata1" not in result


def test_match_sdata_to_table_match_labels_error():
    """
    match_sdata_to_table() relies on the join operations, so when the table annotates labels the `UserWarning`
    checked below is expected to come from the join.
    """

    def _polygons_to_labels(name):
        return name.replace("polygons", "labels")

    # a fresh object is used because the table is modified below
    sdata = _make_test_data()
    table = sdata["table"]
    table.obs["region"] = table.obs["region"].apply(_polygons_to_labels)
    table.obs["region"] = table.obs["region"].astype("category")
    sdata.set_table_annotates_spatialelement(
        table_name="table",
        region=["blobs_labels-sdata1", "blobs_labels-sdata2"],
        region_key="region",
        instance_key="instance_id",
    )

    expected_warning = "Element type `labels` not supported for 'right' join. Skipping "
    with pytest.warns(UserWarning, match=expected_warning):
        result = match_sdata_to_table(sdata, table=sdata["table"], table_name="table")

    assert len(result["table"]) == 10
    for element_name in ("blobs_labels-sdata1", "blobs_labels-sdata2"):
        assert element_name in result
    assert "blobs_points-sdata1" not in result


def test_match_sdata_to_table_no_table_argument():
    """
    When no `table` argument is passed, the table is looked up in the object using `table_name`.
    """
    result = match_sdata_to_table(sdata=sdata, table_name="table")

    assert len(result["table"]) == 10
    for element_name in ("blobs_polygons-sdata1", "blobs_polygons-sdata2"):
        assert element_name in result
assert "blobs_polygons-sdata2" in matched
43 changes: 42 additions & 1 deletion tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from spatialdata._core.validation import ValidationError
from spatialdata._types import ArrayLike
from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES
from spatialdata.models import get_table_keys
from spatialdata.models._utils import (
force_2d,
points_dask_dataframe_to_geopandas,
Expand Down Expand Up @@ -377,6 +378,46 @@ def test_table_model(
assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY]
assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region

# error when trying to parse a table by specifying region, region_key, instance_key, but these keys are
# already set
with pytest.raises(ValueError, match=" has already been set"):
_ = TableModel.parse(adata, region=region, region_key=region_key, instance_key="A")

# error when region is missing
with pytest.raises(ValueError, match="`region` must be provided"):
_ = TableModel.parse(adata, region_key=region_key, instance_key="A", overwrite_metadata=True)

# error when region_key is missing
with pytest.raises(ValueError, match="`region_key` must be provided"):
_ = TableModel.parse(adata, region=region, instance_key="A", overwrite_metadata=True)

# error when instance_key is missing
with pytest.raises(ValueError, match="`instance_key` must be provided"):
_ = TableModel.parse(adata, region=region, region_key=region_key, overwrite_metadata=True)

# we try to overwrite, but the values in the `region_key` column do not match the expected `region` values
with pytest.raises(ValueError, match="values do not match with `region` values"):
_ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True)

# we correctly overwrite; here we check that the metadata is updated
region_, region_key_, instance_key_ = get_table_keys(table)
assert region_ == region
assert region_key_ == region_key
assert instance_key_ == "A"

# let's fix the region_key column
table.obs["B"] = ["element"] * len(table)
_ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True)

region_, region_key_, instance_key_ = get_table_keys(table)
assert region_ == "element"
assert region_key_ == "B"
assert instance_key_ == "C"

# we can parse a table when no metadata is present (i.e. the table does not annotate any element)
del table.uns[TableModel.ATTRS_KEY]
_ = TableModel.parse(table)

@pytest.mark.parametrize(
"name",
[
Expand Down Expand Up @@ -423,7 +464,7 @@ def test_table_instance_key_values_not_unique(self, model: TableModel, region: s
ValueError,
match=re.escape("Instance key column for region(s) `sample_1, sample_2`"),
):
model.parse(adata, region=region, region_key=region_key, instance_key="A")
model.parse(adata, region=region, region_key=region_key, instance_key="A", overwrite_metadata=True)

@pytest.mark.parametrize(
"key",
Expand Down
Loading