Skip to content

Commit 835a035

Browse files
selmanozleyenpre-commit-ci[bot]LucaMarconatoclaude
authored
Support modifying/filtering labels elements via join operations (#946)
* init * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mypy linterrors * update the location and the design * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update docs * make coverage 100/100 because why not * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed type annotation * dont compute eagerly use. delete other instance key for consistency * update the tests and make sure we use match_element_to_table * wip rewrite tests using existing APIs * test passing without using subset_sdata_by_table_mask() * Remove _filter_by_instance_ids and _get_scale_factors; refactor tests to use existing API - Remove _get_scale_factors (duplicated logic already in transformations/_utils.py) - Remove _filter_by_instance_ids and subset_sdata_by_table_mask (superseded by match_sdata_to_table / filter_by_table_query) - Parametrize test_subset_sdata_by_table_mask over both API functions - Replace test_filter_2d_labels_by_instance_ids with test_filter_out_instances, parametrized over both API functions and element types (2D / multiscale labels) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Add filter_label_pixels flag to match_sdata_to_table and filter_by_table_query Threads a filter_label_pixels: bool = False parameter through the full join stack (filter_by_table_query → match_sdata_to_table → join_spatialelement_table → _call_join → _right/_inner_join_spatialelement_table). When True, label pixels for removed instances are zeroed via a new _filter_labels_element helper (handles both DataArray and multiscale DataTree). When False (default), the existing warning is preserved but now also hints at the new flag. Tests no longer need manual _set_instance_ids_in_labels_to_zero calls or warnings suppression. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Change filter_label_pixels default to None; False silences the warning - None (default): warn that label pixels are not filtered, hint at the flag - True: filter label pixels (set removed instance pixels to zero) - False: skip silently, no warning Updated docstrings in join_spatialelement_table, match_sdata_to_table, and filter_by_table_query to document all three states. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Move and consolidate label-filtering tests into test_relational_query_match_sdata_to_table - Replace test_match_sdata_to_table_match_labels_error with test_filter_out_instances: parametrized over both API functions and element types; tests all three filter_label_pixels states (None→warn, False→nullcontext noop, True→pixels filtered) - Add test_subset_sdata_by_table_mask for mixed-element subsetting - Delete test_relational_query_subset_sdata_by_table_mask.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Add 3D labels guard to _filter_labels_element; fix annsel list predicate in tests - Raise NotImplementedError in _filter_labels_element when element is Labels3DModel - Add test_filter_out_instances_3d_labels_not_supported parametrized over both API functions - Use an.col().is_in() instead of == [list] in 3D test (narwhals does not support nested literals) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato <m.lucalmer@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 5c241f9 commit 835a035

2 files changed

Lines changed: 236 additions & 45 deletions

File tree

src/spatialdata/_core/query/relational_query.py

Lines changed: 123 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import dask.array as da
1212
import numpy as np
1313
import pandas as pd
14+
import xarray as xr
1415
from anndata import AnnData
1516
from annsel.core.typing import Predicates
1617
from dask.dataframe import DataFrame as DaskDataFrame
@@ -311,7 +312,10 @@ def _get_masked_element(
311312

312313

313314
def _right_exclusive_join_spatialelement_table(
314-
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
315+
element_dict: dict[str, dict[str, Any]],
316+
table: AnnData,
317+
match_rows: Literal["left", "no", "right"],
318+
filter_label_pixels: bool | None = None,
315319
) -> tuple[dict[str, Any], AnnData | None]:
316320
regions, region_column_name, instance_key = get_table_keys(table)
317321
if isinstance(regions, str):
@@ -349,7 +353,10 @@ def _right_exclusive_join_spatialelement_table(
349353

350354

351355
def _right_join_spatialelement_table(
352-
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
356+
element_dict: dict[str, dict[str, Any]],
357+
table: AnnData,
358+
match_rows: Literal["left", "no", "right"],
359+
filter_label_pixels: bool | None = None,
353360
) -> tuple[dict[str, Any], AnnData]:
354361
if match_rows == "left":
355362
warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2)
@@ -365,11 +372,18 @@ def _right_join_spatialelement_table(
365372
if element_type in ["points", "shapes"]:
366373
element_indices = element.index
367374
else:
368-
warnings.warn(
369-
f"Element type `labels` not supported for 'right' join. Skipping `{name}`",
370-
UserWarning,
371-
stacklevel=2,
372-
)
375+
if filter_label_pixels is True:
376+
element_dict[element_type][name] = _filter_labels_element(
377+
element, table_instance_key_column.tolist()
378+
)
379+
elif filter_label_pixels is None:
380+
warnings.warn(
381+
f"Element type `labels` not supported for 'right' join, pixels are not filtered;"
382+
f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence"
383+
f" this warning. Skipping `{name}`",
384+
UserWarning,
385+
stacklevel=2,
386+
)
373387
continue
374388

375389
masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
@@ -383,7 +397,10 @@ def _right_join_spatialelement_table(
383397

384398

385399
def _inner_join_spatialelement_table(
386-
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
400+
element_dict: dict[str, dict[str, Any]],
401+
table: AnnData,
402+
match_rows: Literal["left", "no", "right"],
403+
filter_label_pixels: bool | None = None,
387404
) -> tuple[dict[str, Any], AnnData]:
388405
regions, region_column_name, instance_key = get_table_keys(table)
389406
if isinstance(regions, str):
@@ -399,11 +416,18 @@ def _inner_join_spatialelement_table(
399416
if element_type in ["points", "shapes"]:
400417
element_indices = element.index
401418
else:
402-
warnings.warn(
403-
f"Element type `labels` not supported for 'inner' join. Skipping `{name}`",
404-
UserWarning,
405-
stacklevel=2,
406-
)
419+
if filter_label_pixels is True:
420+
element_dict[element_type][name] = _filter_labels_element(
421+
element, table_instance_key_column.tolist()
422+
)
423+
elif filter_label_pixels is None:
424+
warnings.warn(
425+
f"Element type `labels` not supported for 'inner' join, pixels are not filtered;"
426+
f" pass `filter_label_pixels=True` to filter or `filter_label_pixels=False` to silence"
427+
f" this warning. Skipping `{name}`",
428+
UserWarning,
429+
stacklevel=2,
430+
)
407431
continue
408432

409433
masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
@@ -429,7 +453,10 @@ def _inner_join_spatialelement_table(
429453

430454

431455
def _left_exclusive_join_spatialelement_table(
432-
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
456+
element_dict: dict[str, dict[str, Any]],
457+
table: AnnData,
458+
match_rows: Literal["left", "no", "right"],
459+
filter_label_pixels: bool | None = None,
433460
) -> tuple[dict[str, Any], AnnData | None]:
434461
regions, region_column_name, instance_key = get_table_keys(table)
435462
if isinstance(regions, str):
@@ -462,7 +489,10 @@ def _left_exclusive_join_spatialelement_table(
462489

463490

464491
def _left_join_spatialelement_table(
465-
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
492+
element_dict: dict[str, dict[str, Any]],
493+
table: AnnData,
494+
match_rows: Literal["left", "no", "right"],
495+
filter_label_pixels: bool | None = None,
466496
) -> tuple[dict[str, Any], AnnData]:
467497
if match_rows == "right":
468498
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
@@ -586,6 +616,7 @@ def join_spatialelement_table(
586616
table: AnnData | None = None,
587617
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left",
588618
match_rows: Literal["no", "left", "right"] = "no",
619+
filter_label_pixels: bool | None = None,
589620
) -> tuple[dict[str, Any], AnnData]:
590621
"""
591622
Join SpatialElement(s) and table together in SQL like manner.
@@ -629,6 +660,11 @@ def join_spatialelement_table(
629660
match_rows
630661
Whether to match the indices of the element and table and if so how. If ``'left'``, element_indices take
631662
priority and if ``'right'`` table instance ids take priority.
663+
filter_label_pixels
664+
Controls pixel-level filtering of label elements for ``'right'`` and ``'inner'`` joins.
665+
If ``True``, pixels whose instance id is not present in the table are set to zero.
666+
If ``None`` (default), label elements are returned unfiltered and a warning is issued.
667+
If ``False``, label elements are returned unfiltered silently (no warning).
632668
633669
Returns
634670
-------
@@ -694,12 +730,16 @@ def join_spatialelement_table(
694730
for name, element in getattr(derived_sdata, element_type).items():
695731
elements_dict[element_type][name] = element
696732

697-
elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows)
733+
elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows, filter_label_pixels)
698734
return elements_dict_joined, table
699735

700736

701737
def _call_join(
702-
elements_dict: dict[str, dict[str, Any]], table: AnnData, how: str, match_rows: Literal["no", "left", "right"]
738+
elements_dict: dict[str, dict[str, Any]],
739+
table: AnnData,
740+
how: str,
741+
match_rows: Literal["no", "left", "right"],
742+
filter_label_pixels: bool | None = None,
703743
) -> tuple[dict[str, Any], AnnData]:
704744
assert any(key in elements_dict for key in ["labels", "shapes", "points"]), (
705745
"No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or "
@@ -714,7 +754,7 @@ def _call_join(
714754
# if how in JoinTypes.__dict__["_member_names_"]:
715755
# hotfix for bug with Python 3.13:
716756
if how in JoinTypes.__dict__:
717-
elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows)
757+
elements_dict, table = getattr(JoinTypes, how)(elements_dict, table, match_rows, filter_label_pixels)
718758
else:
719759
raise TypeError(f"`{how}` is not a valid type of join.")
720760

@@ -797,6 +837,7 @@ def match_sdata_to_table(
797837
table_name: str,
798838
table: AnnData | None = None,
799839
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
840+
filter_label_pixels: bool | None = None,
800841
) -> SpatialData:
801842
"""
802843
Filter the elements of a SpatialData object to match only the rows present in the table.
@@ -812,6 +853,10 @@ def match_sdata_to_table(
812853
`table_name` is used to name the table in the returned `SpatialData` object.
813854
how
814855
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
856+
filter_label_pixels
857+
Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them
858+
unfiltered and warns, ``False`` leaves them unfiltered silently. See
859+
:func:`spatialdata.join_spatialelement_table` for details.
815860
816861
Notes
817862
-----
@@ -823,7 +868,7 @@ def match_sdata_to_table(
823868
_, region_key, instance_key = get_table_keys(table)
824869
annotated_regions = SpatialData.get_annotated_regions(table)
825870
filtered_elements, filtered_table = join_spatialelement_table(
826-
sdata, spatial_element_names=annotated_regions, table=table, how=how
871+
sdata, spatial_element_names=annotated_regions, table=table, how=how, filter_label_pixels=filter_label_pixels
827872
)
828873
filtered_table = TableModel.parse(
829874
filtered_table,
@@ -847,6 +892,7 @@ def filter_by_table_query(
847892
var_names_expr: Predicates | None = None,
848893
layer: str | None = None,
849894
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
895+
filter_label_pixels: bool | None = None,
850896
) -> SpatialData:
851897
"""Filter the SpatialData object based on a set of table queries.
852898
@@ -875,6 +921,10 @@ def filter_by_table_query(
875921
The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`.
876922
how
877923
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
924+
filter_label_pixels
925+
Controls pixel-level filtering of label elements. ``True`` filters pixels, ``None`` (default) leaves them
926+
unfiltered and warns, ``False`` leaves them unfiltered silently. See
927+
:func:`spatialdata.join_spatialelement_table` for details.
878928
879929
Returns
880930
-------
@@ -899,7 +949,13 @@ def filter_by_table_query(
899949
obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer
900950
)
901951

902-
return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how)
952+
return match_sdata_to_table(
953+
sdata=sdata_subset,
954+
table_name=table_name,
955+
table=filtered_table,
956+
how=how,
957+
filter_label_pixels=filter_label_pixels,
958+
)
903959

904960

905961
@dataclass
@@ -1099,3 +1155,50 @@ def get_values(
10991155
return df
11001156

11011157
raise ValueError(f"Unknown origin {origin}")
1158+
1159+
1160+
def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
1161+
# Use apply_ufunc for efficient processing
1162+
# Create a copy to avoid modifying read-only array
1163+
result = block.copy()
1164+
result[np.isin(result, ids_to_remove)] = 0
1165+
return result
1166+
1167+
1168+
def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
1169+
processed = xr.apply_ufunc(
1170+
partial(_mask_block, ids_to_remove=ids_to_remove),
1171+
image,
1172+
input_core_dims=[["y", "x"]],
1173+
output_core_dims=[["y", "x"]],
1174+
vectorize=True,
1175+
dask="parallelized",
1176+
output_dtypes=[image.dtype],
1177+
dataset_fill_value=0,
1178+
dask_gufunc_kwargs={"allow_rechunk": True},
1179+
)
1180+
1181+
# Create a new DataArray to ensure persistence
1182+
return xr.DataArray(
1183+
data=processed.data,
1184+
coords=image.coords,
1185+
dims=image.dims,
1186+
attrs=image.attrs.copy(), # Preserve all attributes
1187+
)
1188+
1189+
1190+
def _filter_labels_element(element: DataArray | DataTree, ids_to_keep: list[int]) -> DataArray | DataTree:
1191+
if get_model(element) is Labels3DModel:
1192+
raise NotImplementedError("Pixel-level filtering of 3D labels is not supported.")
1193+
element_instances = get_element_instances(element)
1194+
ids_to_remove = [i for i in element_instances if i not in set(ids_to_keep)]
1195+
if isinstance(element, DataArray):
1196+
return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove))
1197+
scales = list(element.keys())
1198+
scale_factors = [
1199+
round(element[scales[i]].image.shape[0] / element[scales[i + 1]].image.shape[0]) for i in range(len(scales) - 1)
1200+
]
1201+
return Labels2DModel.parse(
1202+
_set_instance_ids_in_labels_to_zero(element[scales[0]].image, ids_to_remove),
1203+
scale_factors=scale_factors,
1204+
)

0 commit comments

Comments
 (0)