Skip to content

Commit 620c77d

Browse files
Refactor relational queries (#1131)
* fix test_filter_by_coordinate_system_also_table from @giovp * remove unused _filter_table_in_coordinate_systems(); replace assert * Refactor _filter_table_by_elements to delegate to join_spatialelement_table - Replace the bespoke numpy implementation in _filter_table_by_elements with a call to join_spatialelement_table(how="left"), removing the unused match_rows parameter and consolidating to a single code path. - Fix join functions (_left_join, _inner_join, _right_exclusive_join) to update spatialdata_attrs region metadata after filtering. - Fix _right_exclusive_join: restore groupby pattern (consistent with other joins) and use reset_index so integer positions are used as the group index, avoiding IndexError when obs names are duplicated and fixing a latent bug where pd.concat of per-group masks produced a partial-length boolean mask. - Simplify _get_filtered_or_unfiltered_tables to use _filter_table_by_elements. - Replace SpatialData.init_from_elements with get_model() lookup when sdata=None in join_spatialelement_table, removing an expensive importlib call. - Add test_join_updates_spatialdata_attrs covering metadata update for all join types; update test_filter_by_table_query_edge_cases to match new behaviour. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Remove _filter_table_by_element_names and replace with _filter_table_by_elements The old function filtered only by region name, ignoring instance IDs. Replace its only caller (filter_tables by coordinate system) with _filter_table_by_elements, which correctly filters by both region and instance. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent accf496 commit 620c77d

6 files changed

Lines changed: 112 additions & 133 deletions

File tree

src/spatialdata/_core/concatenate.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,6 @@ def concatenate(
219219
return sdata
220220

221221

222-
def _filter_table_in_coordinate_systems(table: AnnData, coordinate_systems: list[str]) -> AnnData:
223-
table_mapping_metadata = table.uns[TableModel.ATTRS_KEY]
224-
region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY]
225-
new_table = table[table.obs[region_key].isin(coordinate_systems)].copy()
226-
new_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_table.obs[region_key].unique().tolist()
227-
return new_table
228-
229-
230222
def _fix_ensure_unique_element_names(
231223
sdatas: dict[str, SpatialData],
232224
rename_tables: bool,

src/spatialdata/_core/query/_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def _get_filtered_or_unfiltered_tables(
226226
return {
227227
name: filtered_table
228228
for name, table in sdata.tables.items()
229-
if (filtered_table := _filter_table_by_elements(table, elements)) and len(filtered_table) != 0
229+
if (filtered_table := _filter_table_by_elements(table, elements)) is not None
230230
}
231-
232231
return sdata.tables

src/spatialdata/_core/query/relational_query.py

Lines changed: 58 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -59,31 +59,6 @@ def get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]:
5959
return table_names
6060

6161

62-
def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None:
63-
"""
64-
Filter an AnnData table to keep only the rows that are in the coordinate system.
65-
66-
Parameters
67-
----------
68-
table
69-
The table to filter; if None, returns None
70-
element_names
71-
The element_names to keep in the tables obs.region column
72-
73-
Returns
74-
-------
75-
The filtered table, or None if the input table was None
76-
"""
77-
if table is None or not table.uns.get(TableModel.ATTRS_KEY):
78-
return None
79-
table_mapping_metadata = table.uns[TableModel.ATTRS_KEY]
80-
region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY]
81-
table.obs = pd.DataFrame(table.obs)
82-
table = table[table.obs[region_key].isin(element_names)].copy()
83-
table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist()
84-
return table
85-
86-
8762
@singledispatch
8863
def get_element_instances(
8964
element: SpatialElement,
@@ -113,7 +88,8 @@ def _(
11388
return_background: bool = False,
11489
) -> pd.Index:
11590
model = get_model(element)
116-
assert model in [Labels2DModel, Labels3DModel], "Expected a `Labels` element. Found an `Image` instead."
91+
if model not in [Labels2DModel, Labels3DModel]:
92+
raise ValueError("Expected a `Labels` element. Found an `Image` instead.")
11793
if isinstance(element, DataArray):
11894
# get unique labels value (including 0 if present)
11995
instances = da.unique(element.data).compute()
@@ -144,88 +120,43 @@ def _(
144120
return element.index
145121

146122

147-
# TODO: replace function use throughout repo by `join_sdata_spatialelement_table`
148-
# TODO: benchmark against join operations before removing
149-
def _filter_table_by_elements(
150-
table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False
151-
) -> AnnData | None:
123+
def _filter_table_by_elements(table: AnnData | None, elements_dict: dict[str, dict[str, Any]]) -> AnnData | None:
152124
"""
153-
Filter an AnnData table to keep only the rows that are in the elements.
125+
Filter an AnnData table to keep only the rows annotating elements in elements_dict.
154126
155127
Parameters
156128
----------
157129
table
158130
The table to filter; if None, returns None
159131
elements_dict
160-
The elements to use to filter the table
161-
match_rows
162-
If True, reorder the table rows to match the order of the elements
132+
The elements to use to filter the table, structured as ``{element_type: {name: element}}``.
133+
Image elements are ignored since tables cannot annotate images.
163134
164135
Returns
165136
-------
166-
The filtered table (eventually with reordered rows), or None if the input table was None.
137+
The filtered table, or None if the input table is None or no rows match.
167138
"""
168-
assert set(elements_dict.keys()).issubset({"images", "labels", "shapes", "points"})
169-
assert len(elements_dict) > 0, "elements_dict must not be empty"
170-
assert any(len(elements) > 0 for elements in elements_dict.values()), (
171-
"elements_dict must contain at least one dict which contains at least one element"
172-
)
173139
if table is None:
174140
return None
175-
to_keep = np.zeros(len(table), dtype=bool)
176-
region_key = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]
177-
instance_key = table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY]
178-
instances = None
179-
for _, elements in elements_dict.items():
180-
for name, element in elements.items():
181-
if get_model(element) == Labels2DModel or get_model(element) == Labels3DModel:
182-
if isinstance(element, DataArray):
183-
# get unique labels value (including 0 if present)
184-
instances = da.unique(element.data).compute()
185-
else:
186-
assert isinstance(element, DataTree)
187-
v = element["scale0"].values()
188-
assert len(v) == 1
189-
xdata = next(iter(v))
190-
# can be slow
191-
instances = da.unique(xdata.data).compute()
192-
instances = np.sort(instances)
193-
elif get_model(element) == ShapesModel:
194-
instances = element.index.to_numpy()
195-
elif get_model(element) == PointsModel:
196-
instances = element.compute().index.to_numpy()
197-
else:
198-
continue
199-
indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy()
200-
to_keep = to_keep | indices
201-
original_table = table
202-
table.obs = pd.DataFrame(table.obs)
203-
table = table[to_keep, :]
204-
if match_rows:
205-
assert instances is not None
206-
assert isinstance(instances, np.ndarray)
207-
assert np.sum(to_keep) != 0, "No row matches in the table annotates the element"
208-
if np.sum(to_keep) != len(instances):
209-
if len(elements_dict) > 1 or len(elements_dict) == 1 and len(next(iter(elements_dict.values()))) > 1:
210-
raise NotImplementedError("Sorting is not supported when filtering by multiple elements")
211-
# case in which the instances in the table and the instances in the element don't correspond
212-
assert "element" in locals()
213-
assert "name" in locals()
214-
n0 = np.setdiff1d(instances, table.obs[instance_key].to_numpy())
215-
n1 = np.setdiff1d(table.obs[instance_key].to_numpy(), instances)
216-
assert len(n1) == 0, f"The table contains {len(n1)} instances that are not in the element: {n1}"
217-
# some instances have not a corresponding row in the table
218-
instances = np.setdiff1d(instances, n0)
219-
assert np.sum(to_keep) == len(instances)
220-
assert sorted(set(instances.tolist())) == sorted(set(table.obs[instance_key].tolist())) # type: ignore[type-var]
221-
table_df = pd.DataFrame({instance_key: table.obs[instance_key], "position": np.arange(len(instances))})
222-
merged = pd.merge(table_df, pd.DataFrame(index=instances), left_on=instance_key, right_index=True, how="right")
223-
matched_positions = merged["position"].to_numpy()
224-
table = table[matched_positions, :]
225-
table = table.copy()
226-
_inplace_fix_subset_categorical_obs(subset_adata=table, original_adata=original_table)
227-
table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist()
228-
return table
141+
elements_by_name = {
142+
name: element
143+
for element_type, name_to_element in elements_dict.items()
144+
if element_type != "images"
145+
for name, element in name_to_element.items()
146+
}
147+
if not elements_by_name:
148+
return None
149+
# Suppress "element not annotated by table" warnings: the table may annotate
150+
# only a subset of the elements passed in, which is expected here.
151+
with warnings.catch_warnings():
152+
warnings.simplefilter("ignore", UserWarning)
153+
_, filtered = join_spatialelement_table(
154+
spatial_element_names=list(elements_by_name.keys()),
155+
spatial_elements=list(elements_by_name.values()),
156+
table=table,
157+
how="left",
158+
)
159+
return filtered if filtered is not None and len(filtered) > 0 else None
229160

230161

231162
def _get_joined_table_indices(
@@ -253,7 +184,7 @@ def _get_joined_table_indices(
253184
-------
254185
The indices that of the table that match the SpatialElement indices.
255186
"""
256-
mask = table_instance_key_column.isin(element_indices)
187+
mask = np.isin(table_instance_key_column.values, element_indices)
257188
if joined_indices is None:
258189
if match_rows == "left":
259190
_, joined_indices = _match_rows(table_instance_key_column, mask, element_indices, match_rows)
@@ -294,7 +225,7 @@ def _get_masked_element(
294225
-------
295226
The masked spatial element based on the provided indices and match rows.
296227
"""
297-
mask = table_instance_key_column.isin(element_indices)
228+
mask = np.isin(table_instance_key_column.values, element_indices)
298229
masked_table_instance_key_column = table_instance_key_column[mask]
299230
mask_values = mask_values if len(mask_values := masked_table_instance_key_column.values) != 0 else None
300231
if match_rows in ["left", "right"]:
@@ -320,8 +251,11 @@ def _right_exclusive_join_spatialelement_table(
320251
regions, region_column_name, instance_key = get_table_keys(table)
321252
if isinstance(regions, str):
322253
regions = [regions]
323-
groups_df = table.obs.groupby(by=region_column_name, observed=False)
324-
mask = []
254+
# reset_index so group_df.index gives integer positions — safe with duplicate obs names
255+
obs = table.obs.reset_index()
256+
groups_df = obs.groupby(by=region_column_name, observed=False)
257+
keep = np.zeros(len(table), dtype=bool)
258+
has_match = False
325259
for element_type, name_element in element_dict.items():
326260
for name, element in name_element.items():
327261
if name in regions:
@@ -331,24 +265,23 @@ def _right_exclusive_join_spatialelement_table(
331265
element_indices = element.index
332266
else:
333267
element_indices = get_element_instances(element)
334-
335-
element_dict[element_type][name] = None
336268
submask = ~table_instance_key_column.isin(element_indices)
337-
mask.append(submask)
269+
keep[group_df.index[submask.values]] = True
270+
has_match = True
271+
element_dict[element_type][name] = None
338272
else:
339273
warnings.warn(
340274
f"The element `{name}` is not annotated by the table. Skipping", UserWarning, stacklevel=2
341275
)
342276
element_dict[element_type][name] = None
343277
continue
344278

345-
if len(mask) != 0:
346-
mask = pd.concat(mask)
347-
exclusive_table = table[mask, :].copy() if mask.sum() != 0 else None # type: ignore[attr-defined]
348-
else:
349-
exclusive_table = None
350-
279+
exclusive_table = table[keep, :] if has_match and keep.any() else None
351280
_inplace_fix_subset_categorical_obs(subset_adata=exclusive_table, original_adata=table)
281+
if exclusive_table is not None:
282+
exclusive_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
283+
exclusive_table.obs[region_column_name].unique().tolist()
284+
)
352285
return element_dict, exclusive_table
353286

354287

@@ -449,6 +382,10 @@ def _inner_join_spatialelement_table(
449382
joined_table = table[joined_indices.tolist(), :].copy() if joined_indices is not None else None
450383

451384
_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
385+
if joined_table is not None:
386+
joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
387+
joined_table.obs[region_column_name].unique().tolist()
388+
)
452389
return element_dict, joined_table
453390

454391

@@ -528,7 +465,10 @@ def _left_join_spatialelement_table(
528465
joined_indices = joined_indices.astype(int)
529466
joined_table = table[joined_indices.tolist(), :].copy() if joined_indices is not None else None
530467
_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
531-
468+
if joined_table is not None:
469+
joined_table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = (
470+
joined_table.obs[region_column_name].unique().tolist()
471+
)
532472
return element_dict, joined_table
533473

534474

@@ -723,11 +663,16 @@ def join_spatialelement_table(
723663
if sdata is not None:
724664
elements_dict = _create_sdata_elements_dict_for_join(sdata, spatial_element_names)
725665
else:
726-
derived_sdata = SpatialData.init_from_elements(dict(zip(spatial_element_names, spatial_elements, strict=True)))
727-
element_types = ["labels", "shapes", "points"]
666+
_model_to_type = {
667+
Labels2DModel: "labels",
668+
Labels3DModel: "labels",
669+
ShapesModel: "shapes",
670+
PointsModel: "points",
671+
}
728672
elements_dict = defaultdict(lambda: defaultdict(dict))
729-
for element_type in element_types:
730-
for name, element in getattr(derived_sdata, element_type).items():
673+
for name, element in zip(spatial_element_names, spatial_elements, strict=True):
674+
element_type = _model_to_type.get(get_model(element))
675+
if element_type is not None:
731676
elements_dict[element_type][name] = element
732677

733678
elements_dict_joined, table = _call_join(elements_dict, table, how, match_rows, filter_label_pixels)

src/spatialdata/_core/spatialdata.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -691,12 +691,17 @@ def _filter_tables(
691691
continue
692692
# each mode here requires paths or elements, using assert here to avoid mypy errors.
693693
if by == "cs":
694-
from spatialdata._core.query.relational_query import (
695-
_filter_table_by_element_names,
696-
)
694+
from spatialdata._core.query.relational_query import _filter_table_by_elements
697695

698696
assert element_names is not None
699-
table = _filter_table_by_element_names(table, element_names)
697+
elements_dict = {}
698+
for element_type in ["images", "labels", "shapes", "points"]:
699+
elements = getattr(self, element_type)
700+
if elements: # Check if the dictionary is not empty
701+
elements_dict[element_type] = {
702+
name: elements[name] for name in element_names if name in elements
703+
}
704+
table = _filter_table_by_elements(table, elements_dict=elements_dict)
700705
if table is not None and len(table) != 0:
701706
tables[table_name] = table
702707
elif by == "elements":

tests/core/operations/test_spatialdata_operations.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,14 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None:
156156
def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None:
157157
from spatialdata.models import TableModel
158158

159-
rng = np.random.default_rng(seed=0)
160-
full_sdata["table"].obs["annotated_shapes"] = pd.Categorical(
161-
rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0])
162-
)
163-
adata = full_sdata["table"]
159+
adata = full_sdata["table"].copy()
160+
161+
circles_instances = full_sdata["circles"].index.values
162+
poly_instances = full_sdata["poly"].index.values
163+
164+
adata = adata[: len(circles_instances) + len(poly_instances), :].copy()
165+
adata.obs["annotated_shapes"] = ["circles"] * len(circles_instances) + ["poly"] * len(poly_instances)
166+
adata.obs["instance_id"] = np.concatenate([circles_instances, poly_instances])
164167
del adata.uns[TableModel.ATTRS_KEY]
165168
full_sdata["table"] = TableModel.parse(
166169
adata,

tests/core/query/test_relational_query.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,41 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation):
246246
assert element_dict["by_polygons"] is None
247247

248248

249+
def test_join_updates_spatialdata_attrs(sdata_query_aggregation):
250+
sdata = sdata_query_aggregation
251+
# table annotates ["values_circles", "values_polygons"]
252+
original_regions = sdata["table"].uns["spatialdata_attrs"]["region"]
253+
assert set(original_regions) == {"values_circles", "values_polygons"}
254+
255+
# left join on a single element: region list must shrink to just that element
256+
_, table = join_spatialelement_table(
257+
sdata=sdata, spatial_element_names="values_circles", table_name="table", how="left"
258+
)
259+
assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"]
260+
261+
# inner join on a single element
262+
_, table = join_spatialelement_table(
263+
sdata=sdata, spatial_element_names="values_circles", table_name="table", how="inner"
264+
)
265+
assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"]
266+
267+
# right_exclusive join: pass a truncated circles element so some table rows have no match.
268+
# values_circles has 9 instances (0-8); keep only 5 → 4 table rows are exclusive.
269+
# Use sdata=None mode so we can pass a truncated element under the original region name.
270+
_, table = join_spatialelement_table(
271+
spatial_element_names=["values_circles"],
272+
spatial_elements=[sdata["values_circles"].iloc[:5]],
273+
table=sdata["table"],
274+
how="right_exclusive",
275+
)
276+
assert table is not None
277+
assert table.n_obs == 4
278+
assert table.uns["spatialdata_attrs"]["region"] == ["values_circles"]
279+
280+
# original table metadata must be unchanged
281+
assert set(sdata["table"].uns["spatialdata_attrs"]["region"]) == {"values_circles", "values_polygons"}
282+
283+
249284
def test_join_spatialelement_table_fail(full_sdata):
250285
with pytest.raises(ValueError, match=" not supported for join operation."):
251286
join_spatialelement_table(
@@ -1214,7 +1249,7 @@ def test_filter_by_table_query_edge_cases(complex_sdata):
12141249
assert var_name.startswith("feature_") and int(var_name.split("_")[1]) < 5
12151250

12161251
# Test 6: Invalid element_names (element doesn't exist)
1217-
with pytest.raises(AssertionError, match="elements_dict must not be empty"):
1252+
with pytest.raises(KeyError, match="shapes_table"):
12181253
filter_by_table_query(
12191254
sdata=sdata,
12201255
table_name="shapes_table",

0 commit comments

Comments
 (0)