diff --git a/docs/api/operations.md b/docs/api/operations.md index f82331c79..937b8dbca 100644 --- a/docs/api/operations.md +++ b/docs/api/operations.md @@ -14,6 +14,7 @@ Operations on `SpatialData` objects. .. autofunction:: join_spatialelement_table .. autofunction:: match_element_to_table .. autofunction:: match_table_to_element +.. autofunction:: match_sdata_to_table .. autofunction:: concatenate .. autofunction:: transform .. autofunction:: rasterize diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 77a8411e7..9ddfea32d 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -40,6 +40,7 @@ "join_spatialelement_table", "match_element_to_table", "match_table_to_element", + "match_sdata_to_table", "SpatialData", "get_extent", "get_centroids", @@ -72,6 +73,7 @@ get_values, join_spatialelement_table, match_element_to_table, + match_sdata_to_table, match_table_to_element, ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 237241e45..b84d43c1b 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -784,6 +784,45 @@ def match_element_to_table( return element_dict, table +def match_sdata_to_table( + sdata: SpatialData, + table_name: str, + table: AnnData | None = None, + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", +) -> SpatialData: + """ + Filter the elements of a SpatialData object to match only the rows present in the table. + + Parameters + ---------- + sdata + SpatialData object containing all the elements and tables. + table + The table to join with the spatial elements. Has precedence over `table_name`. + table_name + The name of the table to join with the SpatialData object if `table` is not provided. 
If table is provided, + `table_name` is used to name the table in the returned `SpatialData` object. + how + The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + + """ + if table is None: + table = sdata[table_name] + _, region_key, instance_key = get_table_keys(table) + annotated_regions = SpatialData.get_annotated_regions(table) + filtered_elements, filtered_table = join_spatialelement_table( + sdata, spatial_element_names=annotated_regions, table=table, how=how + ) + filtered_table = TableModel.parse( + filtered_table, + region=annotated_regions, + region_key=region_key, + instance_key=instance_key, + overwrite_metadata=True, + ) + return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table}) + + @dataclass class _ValueOrigin: origin: str diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index c4cce2142..f011d08f8 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -262,7 +262,7 @@ def from_elements_dict( return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs) @staticmethod - def get_annotated_regions(table: AnnData) -> str | list[str]: + def get_annotated_regions(table: AnnData) -> list[str]: """ Get the regions annotated by a table. @@ -275,8 +275,9 @@ def get_annotated_regions(table: AnnData) -> str | list[str]: ------- The annotated regions. 
""" - regions, _, _ = get_table_keys(table) - return regions + from spatialdata.models.models import _get_region_metadata_from_region_key_column + + return _get_region_metadata_from_region_key_column(table) @staticmethod def get_region_key_column(table: AnnData) -> pd.Series: diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index ad16d23a7..7aeb0b2c0 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -1064,6 +1064,7 @@ def parse( region: str | list[str] | None = None, region_key: str | None = None, instance_key: str | None = None, + overwrite_metadata: bool = False, ) -> AnnData: """ Parse the :class:`anndata.AnnData` to be compatible with the model. @@ -1078,6 +1079,8 @@ def parse( Key in `adata.obs` that specifies the region. instance_key Key in `adata.obs` that specifies the instance. + overwrite_metadata + If `True`, the `region`, `region_key` and `instance_key` metadata will be overwritten. Returns ------- @@ -1087,31 +1090,38 @@ def parse( # either all live in adata.uns or all be passed in as argument n_args = sum([region is not None, region_key is not None, instance_key is not None]) if n_args == 0: - return adata - if n_args > 0: - if cls.ATTRS_KEY in adata.uns: - raise ValueError( - f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" - f"as argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set." - ) - elif cls.ATTRS_KEY in adata.uns: + if cls.ATTRS_KEY not in adata.uns: + # table not annotating any element + return adata attr = adata.uns[cls.ATTRS_KEY] region = attr[cls.REGION_KEY] region_key = attr[cls.REGION_KEY_KEY] instance_key = attr[cls.INSTANCE_KEY] + elif n_args > 0 and not overwrite_metadata and cls.ATTRS_KEY in adata.uns: + raise ValueError( + f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" + f" argument(s). 
However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set." + ) + + if cls.ATTRS_KEY not in adata.uns: + adata.uns[cls.ATTRS_KEY] = {} + if region is None: + raise ValueError(f"`{cls.REGION_KEY}` must be provided.") if region_key is None: raise ValueError(f"`{cls.REGION_KEY_KEY}` must be provided.") + if instance_key is None: + raise ValueError("`instance_key` must be provided.") + if isinstance(region, np.ndarray): region = region.tolist() - if region is None: - raise ValueError(f"`{cls.REGION_KEY}` must be provided.") region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") - if instance_key is None: - raise ValueError("`instance_key` must be provided.") + adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region + adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key + adata.uns[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key # note! this is an expensive check and therefore we skip it during validation # https://github.com/scverse/spatialdata/issues/715 @@ -1214,3 +1224,20 @@ def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]: raise ValueError( "No spatialdata_attrs key found in table.uns, therefore, no table keys found. Please parse the table." ) + + +def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]: + _, region_key, instance_key = get_table_keys(table) + region_key_column = table.obs[region_key] + if not isinstance(region_key_column.dtype, CategoricalDtype): + warnings.warn( + f"The region key column `{region_key}` is not of type `pd.Categorical`. 
Consider casting it to " + f"improve performance.", + UserWarning, + stacklevel=2, + ) + annotated_regions = region_key_column.unique().tolist() + else: + annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist() + assert isinstance(annotated_regions, list) + return annotated_regions diff --git a/tests/core/query/test_relational_query_match_sdata_to_table.py b/tests/core/query/test_relational_query_match_sdata_to_table.py new file mode 100644 index 000000000..6d4fcadf3 --- /dev/null +++ b/tests/core/query/test_relational_query_match_sdata_to_table.py @@ -0,0 +1,144 @@ +import pytest + +from spatialdata import SpatialData, concatenate, match_sdata_to_table +from spatialdata.datasets import blobs_annotating_element + + +def _make_test_data() -> SpatialData: + sdata1 = blobs_annotating_element("blobs_polygons") + sdata2 = blobs_annotating_element("blobs_polygons") + sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True) + sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0])) + return sdata + + +# constructing the example data; let's use a global variable as we can reuse the same object on most tests +# without having to recreate it +sdata = _make_test_data() + + +def test_match_sdata_to_table_filter_specific_instances(): + """ + Filter to keep only specific instances. Note that it works even when the table annotates multiple elements. + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][sdata["table"].obs.instance_id.isin([1, 2])], + table_name="table", + ) + assert len(matched["table"]) == 4 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" in matched + + +def test_match_sdata_to_table_filter_specific_instances_element(): + """ + Filter to keep only specific instances, in a specific element. 
+ """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][ + sdata["table"].obs.instance_id.isin([1, 2]) & (sdata["table"].obs.region == "blobs_polygons-sdata1") + ], + table_name="table", + ) + assert len(matched["table"]) == 2 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" not in matched + + +def test_match_sdata_to_table_filter_by_threshold(): + """ + Filter by a threshold on a value column, in a specific element. + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][sdata["table"].obs.query('value < 5 and region == "blobs_polygons-sdata1"').index], + table_name="table", + ) + assert len(matched["table"]) == 5 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" not in matched + + +def test_match_sdata_to_table_subset_certain_obs(): + """ + Subset to certain obs (we could also subset to certain var or layer). + """ + matched = match_sdata_to_table( + sdata, + table=sdata["table"][[0, 1, 2, 3]], + table_name="table", + ) + assert len(matched["table"]) == 4 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" not in matched + + +def test_match_sdata_to_table_shapes_and_points(): + """ + The function works both for shapes (examples above) and points. + Changes the target of the table to points.
+ """ + sdata = _make_test_data() + sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "points")) + sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") + sdata.set_table_annotates_spatialelement( + table_name="table", + region=["blobs_points-sdata1", "blobs_points-sdata2"], + region_key="region", + instance_key="instance_id", + ) + + matched = match_sdata_to_table( + sdata, + table=sdata["table"], + table_name="table", + ) + + assert len(matched["table"]) == 10 + assert "blobs_points-sdata1" in matched + assert "blobs_points-sdata2" in matched + assert "blobs_polygons-sdata1" not in matched + + +def test_match_sdata_to_table_match_labels_error(): + """ + match_sdata_to_table() uses the join operations; so when trying to match labels, a warning is emitted by the + join. + """ + sdata = _make_test_data() + sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels")) + sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category") + sdata.set_table_annotates_spatialelement( + table_name="table", + region=["blobs_labels-sdata1", "blobs_labels-sdata2"], + region_key="region", + instance_key="instance_id", + ) + + with pytest.warns( + UserWarning, + match="Element type `labels` not supported for 'right' join. Skipping ", + ): + matched = match_sdata_to_table( + sdata, + table=sdata["table"], + table_name="table", + ) + + assert len(matched["table"]) == 10 + assert "blobs_labels-sdata1" in matched + assert "blobs_labels-sdata2" in matched + assert "blobs_points-sdata1" not in matched + + +def test_match_sdata_to_table_no_table_argument(): + """ + If no table argument is passed, the table_name argument will be used to match the table.
+ """ + matched = match_sdata_to_table(sdata=sdata, table_name="table") + + assert len(matched["table"]) == 10 + assert "blobs_polygons-sdata1" in matched + assert "blobs_polygons-sdata2" in matched diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fe2bfea46..0aec3a327 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -29,6 +29,7 @@ from spatialdata._core.validation import ValidationError from spatialdata._types import ArrayLike from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES +from spatialdata.models import get_table_keys from spatialdata.models._utils import ( force_2d, points_dask_dataframe_to_geopandas, @@ -377,6 +378,46 @@ def test_table_model( assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY] assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region + # error when trying to parse a table by specifying region, region_key, instance_key, but these keys are + # already set + with pytest.raises(ValueError, match=" has already been set"): + _ = TableModel.parse(adata, region=region, region_key=region_key, instance_key="A") + + # error when region is missing + with pytest.raises(ValueError, match="`region` must be provided"): + _ = TableModel.parse(adata, region_key=region_key, instance_key="A", overwrite_metadata=True) + + # error when region_key is missing + with pytest.raises(ValueError, match="`region_key` must be provided"): + _ = TableModel.parse(adata, region=region, instance_key="A", overwrite_metadata=True) + + # error when instance_key is missing + with pytest.raises(ValueError, match="`instance_key` must be provided"): + _ = TableModel.parse(adata, region=region, region_key=region_key, overwrite_metadata=True) + + # we try to overwrite, but the values in the `region_key` column do not match the expected `region` values + with pytest.raises(ValueError, match="values do not match with `region` values"): + _ = TableModel.parse(adata, region="element", 
region_key="B", instance_key="C", overwrite_metadata=True) + + # the failed overwrite above must leave the metadata untouched; here we check that it is unchanged + region_, region_key_, instance_key_ = get_table_keys(table) + assert region_ == region + assert region_key_ == region_key + assert instance_key_ == "A" + + # let's fix the region_key column; now the overwrite succeeds and the metadata is updated + table.obs["B"] = ["element"] * len(table) + _ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True) + + region_, region_key_, instance_key_ = get_table_keys(table) + assert region_ == "element" + assert region_key_ == "B" + assert instance_key_ == "C" + + # we can parse a table when no metadata is present (i.e. the table does not annotate any element) + del table.uns[TableModel.ATTRS_KEY] + _ = TableModel.parse(table) + @pytest.mark.parametrize( "name", [ @@ -423,7 +464,7 @@ def test_table_instance_key_values_not_unique(self, model: TableModel, region: s ValueError, match=re.escape("Instance key column for region(s) `sample_1, sample_2`"), ): - model.parse(adata, region=region, region_key=region_key, instance_key="A") + model.parse(adata, region=region, region_key=region_key, instance_key="A", overwrite_metadata=True) @pytest.mark.parametrize( "key",