@@ -59,31 +59,6 @@ def get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]:
5959 return table_names
6060
6161
62- def _filter_table_by_element_names (table : AnnData | None , element_names : str | list [str ]) -> AnnData | None :
63- """
64- Filter an AnnData table to keep only the rows that are in the coordinate system.
65-
66- Parameters
67- ----------
68- table
69- The table to filter; if None, returns None
70- element_names
71- The element_names to keep in the tables obs.region column
72-
73- Returns
74- -------
75- The filtered table, or None if the input table was None
76- """
77- if table is None or not table .uns .get (TableModel .ATTRS_KEY ):
78- return None
79- table_mapping_metadata = table .uns [TableModel .ATTRS_KEY ]
80- region_key = table_mapping_metadata [TableModel .REGION_KEY_KEY ]
81- table .obs = pd .DataFrame (table .obs )
82- table = table [table .obs [region_key ].isin (element_names )].copy ()
83- table .uns [TableModel .ATTRS_KEY ][TableModel .REGION_KEY ] = table .obs [region_key ].unique ().tolist ()
84- return table
85-
86-
8762@singledispatch
8863def get_element_instances (
8964 element : SpatialElement ,
@@ -113,7 +88,8 @@ def _(
11388 return_background : bool = False ,
11489) -> pd .Index :
11590 model = get_model (element )
116- assert model in [Labels2DModel , Labels3DModel ], "Expected a `Labels` element. Found an `Image` instead."
91+ if model not in [Labels2DModel , Labels3DModel ]:
92+ raise ValueError ("Expected a `Labels` element. Found an `Image` instead." )
11793 if isinstance (element , DataArray ):
11894 # get unique labels value (including 0 if present)
11995 instances = da .unique (element .data ).compute ()
@@ -144,88 +120,43 @@ def _(
144120 return element .index
145121
146122
147- # TODO: replace function use throughout repo by `join_sdata_spatialelement_table`
148- # TODO: benchmark against join operations before removing
149- def _filter_table_by_elements (
150- table : AnnData | None , elements_dict : dict [str , dict [str , Any ]], match_rows : bool = False
151- ) -> AnnData | None :
123+ def _filter_table_by_elements (table : AnnData | None , elements_dict : dict [str , dict [str , Any ]]) -> AnnData | None :
152124 """
153- Filter an AnnData table to keep only the rows that are in the elements .
125+ Filter an AnnData table to keep only the rows annotating elements in elements_dict .
154126
155127 Parameters
156128 ----------
157129 table
158130 The table to filter; if None, returns None
159131 elements_dict
160- The elements to use to filter the table
161- match_rows
162- If True, reorder the table rows to match the order of the elements
132+ The elements to use to filter the table, structured as ``{element_type: {name: element}}``.
133+ Image elements are ignored since tables cannot annotate images.
163134
164135 Returns
165136 -------
166- The filtered table (eventually with reordered rows) , or None if the input table was None.
137+ The filtered table, or None if the input table is None or no rows match .
167138 """
168- assert set (elements_dict .keys ()).issubset ({"images" , "labels" , "shapes" , "points" })
169- assert len (elements_dict ) > 0 , "elements_dict must not be empty"
170- assert any (len (elements ) > 0 for elements in elements_dict .values ()), (
171- "elements_dict must contain at least one dict which contains at least one element"
172- )
173139 if table is None :
174140 return None
175- to_keep = np .zeros (len (table ), dtype = bool )
176- region_key = table .uns [TableModel .ATTRS_KEY ][TableModel .REGION_KEY_KEY ]
177- instance_key = table .uns [TableModel .ATTRS_KEY ][TableModel .INSTANCE_KEY ]
178- instances = None
179- for _ , elements in elements_dict .items ():
180- for name , element in elements .items ():
181- if get_model (element ) == Labels2DModel or get_model (element ) == Labels3DModel :
182- if isinstance (element , DataArray ):
183- # get unique labels value (including 0 if present)
184- instances = da .unique (element .data ).compute ()
185- else :
186- assert isinstance (element , DataTree )
187- v = element ["scale0" ].values ()
188- assert len (v ) == 1
189- xdata = next (iter (v ))
190- # can be slow
191- instances = da .unique (xdata .data ).compute ()
192- instances = np .sort (instances )
193- elif get_model (element ) == ShapesModel :
194- instances = element .index .to_numpy ()
195- elif get_model (element ) == PointsModel :
196- instances = element .compute ().index .to_numpy ()
197- else :
198- continue
199- indices = ((table .obs [region_key ] == name ) & (table .obs [instance_key ].isin (instances ))).to_numpy ()
200- to_keep = to_keep | indices
201- original_table = table
202- table .obs = pd .DataFrame (table .obs )
203- table = table [to_keep , :]
204- if match_rows :
205- assert instances is not None
206- assert isinstance (instances , np .ndarray )
207- assert np .sum (to_keep ) != 0 , "No row matches in the table annotates the element"
208- if np .sum (to_keep ) != len (instances ):
209- if len (elements_dict ) > 1 or len (elements_dict ) == 1 and len (next (iter (elements_dict .values ()))) > 1 :
210- raise NotImplementedError ("Sorting is not supported when filtering by multiple elements" )
211- # case in which the instances in the table and the instances in the element don't correspond
212- assert "element" in locals ()
213- assert "name" in locals ()
214- n0 = np .setdiff1d (instances , table .obs [instance_key ].to_numpy ())
215- n1 = np .setdiff1d (table .obs [instance_key ].to_numpy (), instances )
216- assert len (n1 ) == 0 , f"The table contains { len (n1 )} instances that are not in the element: { n1 } "
217- # some instances have not a corresponding row in the table
218- instances = np .setdiff1d (instances , n0 )
219- assert np .sum (to_keep ) == len (instances )
220- assert sorted (set (instances .tolist ())) == sorted (set (table .obs [instance_key ].tolist ())) # type: ignore[type-var]
221- table_df = pd .DataFrame ({instance_key : table .obs [instance_key ], "position" : np .arange (len (instances ))})
222- merged = pd .merge (table_df , pd .DataFrame (index = instances ), left_on = instance_key , right_index = True , how = "right" )
223- matched_positions = merged ["position" ].to_numpy ()
224- table = table [matched_positions , :]
225- table = table .copy ()
226- _inplace_fix_subset_categorical_obs (subset_adata = table , original_adata = original_table )
227- table .uns [TableModel .ATTRS_KEY ][TableModel .REGION_KEY ] = table .obs [region_key ].unique ().tolist ()
228- return table
141+ elements_by_name = {
142+ name : element
143+ for element_type , name_to_element in elements_dict .items ()
144+ if element_type != "images"
145+ for name , element in name_to_element .items ()
146+ }
147+ if not elements_by_name :
148+ return None
149+ # Suppress "element not annotated by table" warnings: the table may annotate
150+ # only a subset of the elements passed in, which is expected here.
151+ with warnings .catch_warnings ():
152+ warnings .simplefilter ("ignore" , UserWarning )
153+ _ , filtered = join_spatialelement_table (
154+ spatial_element_names = list (elements_by_name .keys ()),
155+ spatial_elements = list (elements_by_name .values ()),
156+ table = table ,
157+ how = "left" ,
158+ )
159+ return filtered if filtered is not None and len (filtered ) > 0 else None
229160
230161
231162def _get_joined_table_indices (
@@ -253,7 +184,7 @@ def _get_joined_table_indices(
253184 -------
254185 The indices of the table that match the SpatialElement indices.
255186 """
256- mask = table_instance_key_column .isin (element_indices )
187+ mask = np .isin (table_instance_key_column . values , element_indices )
257188 if joined_indices is None :
258189 if match_rows == "left" :
259190 _ , joined_indices = _match_rows (table_instance_key_column , mask , element_indices , match_rows )
@@ -294,7 +225,7 @@ def _get_masked_element(
294225 -------
295226 The masked spatial element based on the provided indices and match rows.
296227 """
297- mask = table_instance_key_column .isin (element_indices )
228+ mask = np .isin (table_instance_key_column . values , element_indices )
298229 masked_table_instance_key_column = table_instance_key_column [mask ]
299230 mask_values = mask_values if len (mask_values := masked_table_instance_key_column .values ) != 0 else None
300231 if match_rows in ["left" , "right" ]:
@@ -320,8 +251,11 @@ def _right_exclusive_join_spatialelement_table(
320251 regions , region_column_name , instance_key = get_table_keys (table )
321252 if isinstance (regions , str ):
322253 regions = [regions ]
323- groups_df = table .obs .groupby (by = region_column_name , observed = False )
324- mask = []
254+ # reset_index so group_df.index gives integer positions — safe with duplicate obs names
255+ obs = table .obs .reset_index ()
256+ groups_df = obs .groupby (by = region_column_name , observed = False )
257+ keep = np .zeros (len (table ), dtype = bool )
258+ has_match = False
325259 for element_type , name_element in element_dict .items ():
326260 for name , element in name_element .items ():
327261 if name in regions :
@@ -331,24 +265,23 @@ def _right_exclusive_join_spatialelement_table(
331265 element_indices = element .index
332266 else :
333267 element_indices = get_element_instances (element )
334-
335- element_dict [element_type ][name ] = None
336268 submask = ~ table_instance_key_column .isin (element_indices )
337- mask .append (submask )
269+ keep [group_df .index [submask .values ]] = True
270+ has_match = True
271+ element_dict [element_type ][name ] = None
338272 else :
339273 warnings .warn (
340274 f"The element `{ name } ` is not annotated by the table. Skipping" , UserWarning , stacklevel = 2
341275 )
342276 element_dict [element_type ][name ] = None
343277 continue
344278
345- if len (mask ) != 0 :
346- mask = pd .concat (mask )
347- exclusive_table = table [mask , :].copy () if mask .sum () != 0 else None # type: ignore[attr-defined]
348- else :
349- exclusive_table = None
350-
279+ exclusive_table = table [keep , :] if has_match and keep .any () else None
351280 _inplace_fix_subset_categorical_obs (subset_adata = exclusive_table , original_adata = table )
281+ if exclusive_table is not None :
282+ exclusive_table .uns [TableModel .ATTRS_KEY ][TableModel .REGION_KEY ] = (
283+ exclusive_table .obs [region_column_name ].unique ().tolist ()
284+ )
352285 return element_dict , exclusive_table
353286
354287
@@ -449,6 +382,10 @@ def _inner_join_spatialelement_table(
449382 joined_table = table [joined_indices .tolist (), :].copy () if joined_indices is not None else None
450383
451384 _inplace_fix_subset_categorical_obs (subset_adata = joined_table , original_adata = table )
385+ if joined_table is not None :
386+ joined_table .uns [TableModel .ATTRS_KEY ][TableModel .REGION_KEY ] = (
387+ joined_table .obs [region_column_name ].unique ().tolist ()
388+ )
452389 return element_dict , joined_table
453390
454391
@@ -528,7 +465,10 @@ def _left_join_spatialelement_table(
528465 joined_indices = joined_indices .astype (int )
529466 joined_table = table [joined_indices .tolist (), :].copy () if joined_indices is not None else None
530467 _inplace_fix_subset_categorical_obs (subset_adata = joined_table , original_adata = table )
531-
468+ if joined_table is not None :
469+ joined_table .uns [TableModel .ATTRS_KEY ][TableModel .REGION_KEY ] = (
470+ joined_table .obs [region_column_name ].unique ().tolist ()
471+ )
532472 return element_dict , joined_table
533473
534474
@@ -723,11 +663,16 @@ def join_spatialelement_table(
723663 if sdata is not None :
724664 elements_dict = _create_sdata_elements_dict_for_join (sdata , spatial_element_names )
725665 else :
726- derived_sdata = SpatialData .init_from_elements (dict (zip (spatial_element_names , spatial_elements , strict = True )))
727- element_types = ["labels" , "shapes" , "points" ]
666+ _model_to_type = {
667+ Labels2DModel : "labels" ,
668+ Labels3DModel : "labels" ,
669+ ShapesModel : "shapes" ,
670+ PointsModel : "points" ,
671+ }
728672 elements_dict = defaultdict (lambda : defaultdict (dict ))
729- for element_type in element_types :
730- for name , element in getattr (derived_sdata , element_type ).items ():
673+ for name , element in zip (spatial_element_names , spatial_elements , strict = True ):
674+ element_type = _model_to_type .get (get_model (element ))
675+ if element_type is not None :
731676 elements_dict [element_type ][name ] = element
732677
733678 elements_dict_joined , table = _call_join (elements_dict , table , how , match_rows , filter_label_pixels )
0 commit comments