11from __future__ import annotations
22
3- import warnings
43from abc import abstractmethod
54from collections .abc import Callable , Mapping
65from dataclasses import dataclass
98
109import dask .dataframe as dd
1110import numpy as np
11+ import pandas as pd
1212from dask .dataframe import DataFrame as DaskDataFrame
1313from geopandas import GeoDataFrame
1414from shapely .geometry import MultiPolygon , Point , Polygon
@@ -78,7 +78,7 @@ def _get_bounding_box_corners_in_intrinsic_coordinates(
7878
7979 # compute the output axes of the transformation, remove c from input and output axes, return the matrix without c
8080 # and then build an affine transformation from that
81- m_without_c , input_axes_without_c , output_axes_without_c = _get_axes_of_tranformation (
81+ m_without_c , input_axes_without_c , output_axes_without_c = _get_axes_of_transformation (
8282 element , target_coordinate_system
8383 )
8484 spatial_transform = Affine (m_without_c , input_axes = input_axes_without_c , output_axes = output_axes_without_c )
@@ -142,7 +142,7 @@ def _get_polygon_in_intrinsic_coordinates(
142142
143143 polygon_gdf = ShapesModel .parse (GeoDataFrame (geometry = [polygon ]))
144144
145- m_without_c , input_axes_without_c , output_axes_without_c = _get_axes_of_tranformation (
145+ m_without_c , input_axes_without_c , output_axes_without_c = _get_axes_of_transformation (
146146 element , target_coordinate_system
147147 )
148148 spatial_transform = Affine (m_without_c , input_axes = input_axes_without_c , output_axes = output_axes_without_c )
@@ -186,7 +186,7 @@ def _get_polygon_in_intrinsic_coordinates(
186186 return transform (polygon_gdf , to_coordinate_system = "inverse" )
187187
188188
189- def _get_axes_of_tranformation (
189+ def _get_axes_of_transformation (
190190 element : SpatialElement , target_coordinate_system : str
191191) -> tuple [ArrayLike , tuple [str , ...], tuple [str , ...]]:
192192 """
@@ -321,6 +321,11 @@ def _get_case_of_bounding_box_query(
321321 return case
322322
323323
324+ def _is_scaling_transform (m_linear : np .ndarray ) -> bool :
325+ """Check if the linear part is a diagonal (pure scaling) matrix."""
326+ return np .allclose (m_linear , np .diag (np .diagonal (m_linear )))
327+
328+
324329@dataclass (frozen = True )
325330class BaseSpatialRequest :
326331 """Base class for spatial queries."""
@@ -382,7 +387,7 @@ def to_dict(self) -> dict[str, Any]:
382387
383388@docstring_parameter (min_coordinate_docs = MIN_COORDINATE_DOCS , max_coordinate_docs = MAX_COORDINATE_DOCS )
384389def _bounding_box_mask_points (
385- points : DaskDataFrame ,
390+ points_df : pd . DataFrame ,
386391 axes : tuple [str , ...],
387392 min_coordinate : list [Number ] | ArrayLike ,
388393 max_coordinate : list [Number ] | ArrayLike ,
@@ -391,8 +396,8 @@ def _bounding_box_mask_points(
391396
392397 Parameters
393398 ----------
394- points
395- The points element to perform the query on.
399+ points_df
400+ A pre-computed pandas dataframe representing the points element to perform the query on.
396401 axes
397402 The axes that min_coordinate and max_coordinate refer to.
398403 min_coordinate
@@ -405,30 +410,28 @@ def _bounding_box_mask_points(
405410 Shape: (n_boxes, n_axes) or (n_axes,) for a single box.
406411 {max_coordinate_docs}
407412
413+
408414 Returns
409415 -------
410416 The masks for the points inside the bounding boxes.
411417 """
412- element_axes = get_axes_names (points )
413-
418+ element_axes = get_axes_names (points_df )
414419 min_coordinate = _parse_list_into_array (min_coordinate )
415420 max_coordinate = _parse_list_into_array (max_coordinate )
416-
417- # Ensure min_coordinate and max_coordinate are 2D arrays
418421 min_coordinate = min_coordinate [np .newaxis , :] if min_coordinate .ndim == 1 else min_coordinate
419422 max_coordinate = max_coordinate [np .newaxis , :] if max_coordinate .ndim == 1 else max_coordinate
420423
421424 n_boxes = min_coordinate .shape [0 ]
422425 in_bounding_box_masks = []
423-
424426 for box in range (n_boxes ):
425427 box_masks = []
426428 for axis_index , axis_name in enumerate (axes ):
427429 if axis_name not in element_axes :
428430 continue
429431 min_value = min_coordinate [box , axis_index ]
430432 max_value = max_coordinate [box , axis_index ]
431- box_masks .append (points [axis_name ].gt (min_value ).compute () & points [axis_name ].lt (max_value ).compute ())
433+ col = points_df [axis_name ].values
434+ box_masks .append ((col > min_value ) & (col < max_value ))
432435 bounding_box_mask = np .stack (box_masks , axis = - 1 )
433436 in_bounding_box_masks .append (np .all (bounding_box_mask , axis = 1 ))
434437 return in_bounding_box_masks
@@ -514,16 +517,6 @@ def _(
514517 min_coordinate = _parse_list_into_array (min_coordinate )
515518 max_coordinate = _parse_list_into_array (max_coordinate )
516519 new_elements = {}
517- if sdata .points :
518- warnings .warn (
519- (
520- "The object has `points` element. Depending on the number of points, querying MAY suffer from "
521- "performance issues. Please consider filtering the object before calling this function by calling the "
522- "`subset()` method of `SpatialData`."
523- ),
524- UserWarning ,
525- stacklevel = 2 ,
526- )
527520 for element_type in ["points" , "images" , "labels" , "shapes" ]:
528521 elements = getattr (sdata , element_type )
529522 queried_elements = _dict_query_dispatcher (
@@ -630,7 +623,6 @@ def _(
630623 max_coordinate : list [Number ] | ArrayLike ,
631624 target_coordinate_system : str ,
632625) -> DaskDataFrame | list [DaskDataFrame ] | None :
633- from spatialdata import transform
634626 from spatialdata .transformations import get_transformation
635627
636628 min_coordinate = _parse_list_into_array (min_coordinate )
@@ -640,6 +632,7 @@ def _(
640632 min_coordinate = min_coordinate [np .newaxis , :] if min_coordinate .ndim == 1 else min_coordinate
641633 max_coordinate = max_coordinate [np .newaxis , :] if max_coordinate .ndim == 1 else max_coordinate
642634
635+ # the code below is taken from _get_bounding_box_corners_in_intrinsic_coordinates()
643636 # for triggering validation
644637 _ = BoundingBoxRequest (
645638 target_coordinate_system = target_coordinate_system ,
@@ -648,100 +641,101 @@ def _(
648641 max_coordinate = max_coordinate ,
649642 )
650643
651- # get the four corners of the bounding box (2D case), or the 8 corners of the "3D bounding box" (3D case)
652- (intrinsic_bounding_box_corners , intrinsic_axes ) = _get_bounding_box_corners_in_intrinsic_coordinates (
653- element = points ,
654- axes = axes ,
655- min_coordinate = min_coordinate ,
656- max_coordinate = max_coordinate ,
657- target_coordinate_system = target_coordinate_system ,
644+ m_without_c , input_axes_without_c , output_axes_without_c = _get_axes_of_transformation (
645+ points , target_coordinate_system
658646 )
659- min_coordinate_intrinsic = intrinsic_bounding_box_corners .min (dim = "corner" )
660- max_coordinate_intrinsic = intrinsic_bounding_box_corners .max (dim = "corner" )
661-
662- min_coordinate_intrinsic = min_coordinate_intrinsic .data
663- max_coordinate_intrinsic = max_coordinate_intrinsic .data
664-
665- # get the points in the intrinsic coordinate bounding box
666- in_intrinsic_bounding_box = _bounding_box_mask_points (
667- points = points ,
668- axes = intrinsic_axes ,
669- min_coordinate = min_coordinate_intrinsic ,
670- max_coordinate = max_coordinate_intrinsic ,
647+ m_without_c_linear = m_without_c [:- 1 , :- 1 ]
648+ _ = _get_case_of_bounding_box_query (
649+ m_without_c_linear ,
650+ input_axes_without_c ,
651+ output_axes_without_c ,
652+ )
653+ axes_adjusted , min_coordinate_adjusted , max_coordinate_adjusted = _adjust_bounding_box_to_real_axes (
654+ axes ,
655+ min_coordinate ,
656+ max_coordinate ,
657+ output_axes_without_c ,
658+ )
659+ if set (axes_adjusted ) != set (output_axes_without_c ):
660+ raise ValueError ("The axes of the bounding box must match the axes of the transformation." )
661+
662+ # materialize the points in the intrinsic coordinate system once
663+ points_pd = points .compute ()
664+
665+ # checking the type of the transformation
666+ # in the case of an identity or scaling transform, we can skip the whole
667+ # projection into intrinsic space and reprojection into the global coordinate system
668+ is_identity_transform = input_axes_without_c == output_axes_without_c and np .allclose (
669+ m_without_c , np .eye (m_without_c .shape [0 ])
671670 )
671+ is_scaling_transform = input_axes_without_c == output_axes_without_c and _is_scaling_transform (m_without_c_linear )
672+
673+ # if the transform is identity, we can save extra for the affine transformation
674+ if is_identity_transform :
675+ bounding_box_masks = _bounding_box_mask_points (
676+ points_df = points_pd ,
677+ axes = axes_adjusted ,
678+ min_coordinate = min_coordinate_adjusted ,
679+ max_coordinate = max_coordinate_adjusted ,
680+ )
681+ elif is_scaling_transform :
682+ # Pull scale factors from the diagonal and the translation from the last column
683+ scales = np .diagonal (m_without_c_linear ) # shape: (n_axes,)
684+ translation = m_without_c [:- 1 , - 1 ] # shape: (n_axes,)
685+
686+ # Invert the affine: x_intrinsic = (x_output - translation) / scale
687+ min_intrinsic = (min_coordinate_adjusted - translation ) / scales
688+ max_intrinsic = (max_coordinate_adjusted - translation ) / scales
689+
690+ # Negative scale components flip the interval; restore min < max.
691+ min_intrinsic , max_intrinsic = (
692+ np .minimum (min_intrinsic , max_intrinsic ),
693+ np .maximum (min_intrinsic , max_intrinsic ),
694+ )
672695
673- if not (len_df := len (in_intrinsic_bounding_box )) == (len_bb := len (min_coordinate )):
674- raise ValueError (
675- f"Length of list of dataframes `{ len_df } ` is not equal to the number of bounding boxes axes `{ len_bb } `."
696+ bounding_box_masks = _bounding_box_mask_points (
697+ points_df = points_pd ,
698+ axes = tuple (input_axes_without_c ),
699+ min_coordinate = min_intrinsic ,
700+ max_coordinate = max_intrinsic ,
676701 )
677- points_in_intrinsic_bounding_box : list [DaskDataFrame | None ] = []
678- points_pd = points .compute ()
679- attrs = points .attrs .copy ()
680- for mask_np in in_intrinsic_bounding_box :
681- if mask_np .sum () == 0 :
682- points_in_intrinsic_bounding_box .append (None )
683- else :
684- # TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now.
685- # we can't compute either mask or points as when we calculate either one of them
686- # test_query_points_multiple_partitions will fail as the mask will be used to index each partition.
687- # However, if we compute and then create the dask array again we get the mixed dask graph problem.
688- filtered_pd = points_pd [mask_np ]
689- points_filtered = dd .from_pandas (filtered_pd , npartitions = points .npartitions )
690- points_filtered .attrs .update (attrs )
691- points_in_intrinsic_bounding_box .append (points_filtered )
692- if len (points_in_intrinsic_bounding_box ) == 0 :
693- return None
702+ else :
703+ query_coordinates = points_pd .loc [:, list (input_axes_without_c )].to_numpy (copy = False )
704+ query_coordinates = query_coordinates @ m_without_c [:- 1 , :- 1 ].T + m_without_c [:- 1 , - 1 ]
705+
706+ bounding_box_masks = []
707+ for box_index in range (min_coordinate_adjusted .shape [0 ]):
708+ bounding_box_mask = np .ones (len (points_pd ), dtype = bool )
709+ for axis_index in range (len (output_axes_without_c )):
710+ min_value = min_coordinate_adjusted [box_index , axis_index ]
711+ max_value = max_coordinate_adjusted [box_index , axis_index ]
712+ column = query_coordinates [:, axis_index ]
713+ bounding_box_mask &= (column > min_value ) & (column < max_value )
714+ bounding_box_masks .append (bounding_box_mask )
715+
716+ if not (len_df := len (bounding_box_masks )) == (len_bb := len (min_coordinate )):
717+ raise ValueError (f"Length of list of masks `{ len_df } ` is not equal to the number of bounding boxes `{ len_bb } `." )
718+
719+ old_transformations = get_transformation (points , get_all = True )
720+ assert isinstance (old_transformations , dict )
721+ feature_key = points .attrs .get (ATTRS_KEY , {}).get (PointsModel .FEATURE_KEY )
694722
695- # assert that the number of queried points is correct
696- assert len (points_in_intrinsic_bounding_box ) == len (min_coordinate )
697-
698- # # we have to reset the index since we have subset
699- # # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask
700- # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.assign(idx=1)
701- # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.set_index(
702- # points_in_intrinsic_bounding_box.idx.cumsum() - 1
703- # )
704- # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.map_partitions(
705- # lambda df: df.rename(index={"idx": None})
706- # )
707- # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"])
708-
709- # transform the element to the query coordinate system
710723 output : list [DaskDataFrame | None ] = []
711- for p , min_c , max_c in zip (points_in_intrinsic_bounding_box , min_coordinate , max_coordinate , strict = True ):
712- if p is None :
724+ for mask_np in bounding_box_masks :
725+ bounding_box_indices = np .flatnonzero (mask_np )
726+ if len (bounding_box_indices ) == 0 :
713727 output .append (None )
714- else :
715- points_query_coordinate_system = transform (
716- p , to_coordinate_system = target_coordinate_system , maintain_positioning = False
728+ continue
729+
730+ # The exact mask is computed in the query coordinate system, but the returned points must stay intrinsic.
731+ queried_points = points_pd .iloc [bounding_box_indices ]
732+ output .append (
733+ PointsModel .parse (
734+ dd .from_pandas (queried_points , npartitions = points .npartitions ),
735+ transformations = old_transformations .copy (),
736+ feature_key = feature_key ,
717737 )
718-
719- # get a mask for the points in the bounding box
720- bounding_box_mask = _bounding_box_mask_points (
721- points = points_query_coordinate_system ,
722- axes = axes ,
723- min_coordinate = min_c , # type: ignore[arg-type]
724- max_coordinate = max_c , # type: ignore[arg-type]
725- )
726- if len (bounding_box_mask ) != 1 :
727- raise ValueError (f"Expected a single mask, got { len (bounding_box_mask )} masks. Please report this bug." )
728- bounding_box_indices = np .where (bounding_box_mask [0 ])[0 ]
729-
730- if len (bounding_box_indices ) == 0 :
731- output .append (None )
732- else :
733- points_df = p .compute ().iloc [bounding_box_indices ]
734- old_transformations = get_transformation (p , get_all = True )
735- assert isinstance (old_transformations , dict )
736- feature_key = p .attrs .get (ATTRS_KEY , {}).get (PointsModel .FEATURE_KEY )
737-
738- output .append (
739- PointsModel .parse (
740- dd .from_pandas (points_df , npartitions = 1 ),
741- transformations = old_transformations .copy (),
742- feature_key = feature_key ,
743- )
744- )
738+ )
745739 if len (output ) == 0 :
746740 return None
747741 if len (output ) == 1 :
@@ -791,8 +785,8 @@ def _(
791785 )
792786 for box_corners in intrinsic_bounding_box_corners :
793787 bounding_box_non_axes_aligned = Polygon (box_corners .data )
794- indices = polygons .geometry . intersects (bounding_box_non_axes_aligned )
795- queried = polygons [ indices ]
788+ candidate_idx = polygons .sindex . query (bounding_box_non_axes_aligned , predicate = "intersects" )
789+ queried = polygons . iloc [ candidate_idx ]
796790 if len (queried ) == 0 :
797791 queried_polygon = None
798792 else :
@@ -949,17 +943,22 @@ def _(
949943 assert np .all (element [OLD_INDEX ] == buffered .index )
950944 else :
951945 buffered [OLD_INDEX ] = buffered .index
952- indices = buffered .geometry .apply (lambda x : x .intersects (polygon ))
953- if np .sum (indices ) == 0 :
946+
947+ # Use sindex for fast candidate pre-filtering, then exact intersection check
948+ # only on the (typically small) candidate set — same pattern as bounding_box_query.
949+ candidate_idx = buffered .sindex .query (polygon , predicate = "intersects" )
950+ if len (candidate_idx ) == 0 :
951+ del buffered [OLD_INDEX ]
954952 return None
955- queried_shapes = element [indices ]
956- queried_shapes .index = buffered [indices ][OLD_INDEX ]
953+
954+ queried_shapes = element .iloc [candidate_idx ].copy ()
955+ queried_shapes .index = buffered .iloc [candidate_idx ][OLD_INDEX ]
957956 queried_shapes .index .name = None
958957
959958 if clip :
960959 if isinstance (element .geometry .iloc [0 ], Point ):
961- queried_shapes = buffered [ indices ]
962- queried_shapes .index = buffered [ indices ][OLD_INDEX ]
960+ queried_shapes = buffered . iloc [ candidate_idx ]
961+ queried_shapes .index = buffered . iloc [ candidate_idx ][OLD_INDEX ]
963962 queried_shapes .index .name = None
964963 queried_shapes = queried_shapes .clip (polygon_gdf , keep_geom_type = True )
965964
0 commit comments