Skip to content

Commit d8bf265

Browse files
MeyerBender, pre-commit-ci[bot], dschaub95, LucaMarconato, claude
authored
Speedup for bounding_box_query (#1104)
* Removed unnecessary compute() call
* Using the R-tree for spatial querying of shapes
* Cleanup
* Cleanup
* Cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Optimize transformed point bounding box queries and add multi-box coverage
* add comments
* Used spatial indexing also for polygon query
* sped up querying for scaling transformation and removed warning about performance issues
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix ruff pre-commit violations in spatial_query.py

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
* code review: remove points parameter in bounding box internal function; add extra tests for querying points; remove unnecessary copy() in polygon query of points
* Fix bounding_box_query: restore negative-scale interval swap, use axes_adjusted in identity path, restore npartitions
  - Restore np.minimum/np.maximum swap so axis-flip transformations (negative scale) no longer raise ValueError; add regression test
  - Use axes_adjusted/min_coordinate_adjusted consistently in the identity path
  - Revert npartitions=1 back to points.npartitions in result construction
  - Add test for general affine transform (rotate 45° + translate) in aligned space

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: dschaub95 <schaub.darius@gmail.com>
Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com>
Co-authored-by: Luca Marconato <m.lucalmer@gmail.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent ea87931 commit d8bf265

3 files changed

Lines changed: 218 additions & 118 deletions

File tree

src/spatialdata/_core/query/spatial_query.py

Lines changed: 117 additions & 118 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import warnings
43
from abc import abstractmethod
54
from collections.abc import Callable, Mapping
65
from dataclasses import dataclass
@@ -9,6 +8,7 @@
98

109
import dask.dataframe as dd
1110
import numpy as np
11+
import pandas as pd
1212
from dask.dataframe import DataFrame as DaskDataFrame
1313
from geopandas import GeoDataFrame
1414
from shapely.geometry import MultiPolygon, Point, Polygon
@@ -78,7 +78,7 @@ def _get_bounding_box_corners_in_intrinsic_coordinates(
7878

7979
# compute the output axes of the transformation, remove c from input and output axes, return the matrix without c
8080
# and then build an affine transformation from that
81-
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation(
81+
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_transformation(
8282
element, target_coordinate_system
8383
)
8484
spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c)
@@ -142,7 +142,7 @@ def _get_polygon_in_intrinsic_coordinates(
142142

143143
polygon_gdf = ShapesModel.parse(GeoDataFrame(geometry=[polygon]))
144144

145-
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation(
145+
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_transformation(
146146
element, target_coordinate_system
147147
)
148148
spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c)
@@ -186,7 +186,7 @@ def _get_polygon_in_intrinsic_coordinates(
186186
return transform(polygon_gdf, to_coordinate_system="inverse")
187187

188188

189-
def _get_axes_of_tranformation(
189+
def _get_axes_of_transformation(
190190
element: SpatialElement, target_coordinate_system: str
191191
) -> tuple[ArrayLike, tuple[str, ...], tuple[str, ...]]:
192192
"""
@@ -321,6 +321,11 @@ def _get_case_of_bounding_box_query(
321321
return case
322322

323323

324+
def _is_scaling_transform(m_linear: np.ndarray) -> bool:
325+
"""Check if the linear part is a diagonal (pure scaling) matrix."""
326+
return np.allclose(m_linear, np.diag(np.diagonal(m_linear)))
327+
328+
324329
@dataclass(frozen=True)
325330
class BaseSpatialRequest:
326331
"""Base class for spatial queries."""
@@ -382,7 +387,7 @@ def to_dict(self) -> dict[str, Any]:
382387

383388
@docstring_parameter(min_coordinate_docs=MIN_COORDINATE_DOCS, max_coordinate_docs=MAX_COORDINATE_DOCS)
384389
def _bounding_box_mask_points(
385-
points: DaskDataFrame,
390+
points_df: pd.DataFrame,
386391
axes: tuple[str, ...],
387392
min_coordinate: list[Number] | ArrayLike,
388393
max_coordinate: list[Number] | ArrayLike,
@@ -391,8 +396,8 @@ def _bounding_box_mask_points(
391396
392397
Parameters
393398
----------
394-
points
395-
The points element to perform the query on.
399+
points_df
400+
A pre-computed pandas dataframe representing the points element to perform the query on.
396401
axes
397402
The axes that min_coordinate and max_coordinate refer to.
398403
min_coordinate
@@ -405,30 +410,28 @@ def _bounding_box_mask_points(
405410
Shape: (n_boxes, n_axes) or (n_axes,) for a single box.
406411
{max_coordinate_docs}
407412
413+
408414
Returns
409415
-------
410416
The masks for the points inside the bounding boxes.
411417
"""
412-
element_axes = get_axes_names(points)
413-
418+
element_axes = get_axes_names(points_df)
414419
min_coordinate = _parse_list_into_array(min_coordinate)
415420
max_coordinate = _parse_list_into_array(max_coordinate)
416-
417-
# Ensure min_coordinate and max_coordinate are 2D arrays
418421
min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate
419422
max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate
420423

421424
n_boxes = min_coordinate.shape[0]
422425
in_bounding_box_masks = []
423-
424426
for box in range(n_boxes):
425427
box_masks = []
426428
for axis_index, axis_name in enumerate(axes):
427429
if axis_name not in element_axes:
428430
continue
429431
min_value = min_coordinate[box, axis_index]
430432
max_value = max_coordinate[box, axis_index]
431-
box_masks.append(points[axis_name].gt(min_value).compute() & points[axis_name].lt(max_value).compute())
433+
col = points_df[axis_name].values
434+
box_masks.append((col > min_value) & (col < max_value))
432435
bounding_box_mask = np.stack(box_masks, axis=-1)
433436
in_bounding_box_masks.append(np.all(bounding_box_mask, axis=1))
434437
return in_bounding_box_masks
@@ -514,16 +517,6 @@ def _(
514517
min_coordinate = _parse_list_into_array(min_coordinate)
515518
max_coordinate = _parse_list_into_array(max_coordinate)
516519
new_elements = {}
517-
if sdata.points:
518-
warnings.warn(
519-
(
520-
"The object has `points` element. Depending on the number of points, querying MAY suffer from "
521-
"performance issues. Please consider filtering the object before calling this function by calling the "
522-
"`subset()` method of `SpatialData`."
523-
),
524-
UserWarning,
525-
stacklevel=2,
526-
)
527520
for element_type in ["points", "images", "labels", "shapes"]:
528521
elements = getattr(sdata, element_type)
529522
queried_elements = _dict_query_dispatcher(
@@ -630,7 +623,6 @@ def _(
630623
max_coordinate: list[Number] | ArrayLike,
631624
target_coordinate_system: str,
632625
) -> DaskDataFrame | list[DaskDataFrame] | None:
633-
from spatialdata import transform
634626
from spatialdata.transformations import get_transformation
635627

636628
min_coordinate = _parse_list_into_array(min_coordinate)
@@ -640,6 +632,7 @@ def _(
640632
min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate
641633
max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate
642634

635+
# the code below is taken from _get_bounding_box_corners_in_intrinsic_coordinates()
643636
# for triggering validation
644637
_ = BoundingBoxRequest(
645638
target_coordinate_system=target_coordinate_system,
@@ -648,100 +641,101 @@ def _(
648641
max_coordinate=max_coordinate,
649642
)
650643

651-
# get the four corners of the bounding box (2D case), or the 8 corners of the "3D bounding box" (3D case)
652-
(intrinsic_bounding_box_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates(
653-
element=points,
654-
axes=axes,
655-
min_coordinate=min_coordinate,
656-
max_coordinate=max_coordinate,
657-
target_coordinate_system=target_coordinate_system,
644+
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_transformation(
645+
points, target_coordinate_system
658646
)
659-
min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner")
660-
max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner")
661-
662-
min_coordinate_intrinsic = min_coordinate_intrinsic.data
663-
max_coordinate_intrinsic = max_coordinate_intrinsic.data
664-
665-
# get the points in the intrinsic coordinate bounding box
666-
in_intrinsic_bounding_box = _bounding_box_mask_points(
667-
points=points,
668-
axes=intrinsic_axes,
669-
min_coordinate=min_coordinate_intrinsic,
670-
max_coordinate=max_coordinate_intrinsic,
647+
m_without_c_linear = m_without_c[:-1, :-1]
648+
_ = _get_case_of_bounding_box_query(
649+
m_without_c_linear,
650+
input_axes_without_c,
651+
output_axes_without_c,
652+
)
653+
axes_adjusted, min_coordinate_adjusted, max_coordinate_adjusted = _adjust_bounding_box_to_real_axes(
654+
axes,
655+
min_coordinate,
656+
max_coordinate,
657+
output_axes_without_c,
658+
)
659+
if set(axes_adjusted) != set(output_axes_without_c):
660+
raise ValueError("The axes of the bounding box must match the axes of the transformation.")
661+
662+
# materialize the points in the intrinsic coordinate system once
663+
points_pd = points.compute()
664+
665+
# checking the type of the transformation
666+
# in the case of an identity or scaling transform, we can skip the whole
667+
# projection into intrinsic space and reprojection into the global coordinate system
668+
is_identity_transform = input_axes_without_c == output_axes_without_c and np.allclose(
669+
m_without_c, np.eye(m_without_c.shape[0])
671670
)
671+
is_scaling_transform = input_axes_without_c == output_axes_without_c and _is_scaling_transform(m_without_c_linear)
672+
673+
# if the transform is identity, we can save extra for the affine transformation
674+
if is_identity_transform:
675+
bounding_box_masks = _bounding_box_mask_points(
676+
points_df=points_pd,
677+
axes=axes_adjusted,
678+
min_coordinate=min_coordinate_adjusted,
679+
max_coordinate=max_coordinate_adjusted,
680+
)
681+
elif is_scaling_transform:
682+
# Pull scale factors from the diagonal and the translation from the last column
683+
scales = np.diagonal(m_without_c_linear) # shape: (n_axes,)
684+
translation = m_without_c[:-1, -1] # shape: (n_axes,)
685+
686+
# Invert the affine: x_intrinsic = (x_output - translation) / scale
687+
min_intrinsic = (min_coordinate_adjusted - translation) / scales
688+
max_intrinsic = (max_coordinate_adjusted - translation) / scales
689+
690+
# Negative scale components flip the interval; restore min < max.
691+
min_intrinsic, max_intrinsic = (
692+
np.minimum(min_intrinsic, max_intrinsic),
693+
np.maximum(min_intrinsic, max_intrinsic),
694+
)
672695

673-
if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)):
674-
raise ValueError(
675-
f"Length of list of dataframes `{len_df}` is not equal to the number of bounding boxes axes `{len_bb}`."
696+
bounding_box_masks = _bounding_box_mask_points(
697+
points_df=points_pd,
698+
axes=tuple(input_axes_without_c),
699+
min_coordinate=min_intrinsic,
700+
max_coordinate=max_intrinsic,
676701
)
677-
points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = []
678-
points_pd = points.compute()
679-
attrs = points.attrs.copy()
680-
for mask_np in in_intrinsic_bounding_box:
681-
if mask_np.sum() == 0:
682-
points_in_intrinsic_bounding_box.append(None)
683-
else:
684-
# TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now.
685-
# we can't compute either mask or points as when we calculate either one of them
686-
# test_query_points_multiple_partitions will fail as the mask will be used to index each partition.
687-
# However, if we compute and then create the dask array again we get the mixed dask graph problem.
688-
filtered_pd = points_pd[mask_np]
689-
points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions)
690-
points_filtered.attrs.update(attrs)
691-
points_in_intrinsic_bounding_box.append(points_filtered)
692-
if len(points_in_intrinsic_bounding_box) == 0:
693-
return None
702+
else:
703+
query_coordinates = points_pd.loc[:, list(input_axes_without_c)].to_numpy(copy=False)
704+
query_coordinates = query_coordinates @ m_without_c[:-1, :-1].T + m_without_c[:-1, -1]
705+
706+
bounding_box_masks = []
707+
for box_index in range(min_coordinate_adjusted.shape[0]):
708+
bounding_box_mask = np.ones(len(points_pd), dtype=bool)
709+
for axis_index in range(len(output_axes_without_c)):
710+
min_value = min_coordinate_adjusted[box_index, axis_index]
711+
max_value = max_coordinate_adjusted[box_index, axis_index]
712+
column = query_coordinates[:, axis_index]
713+
bounding_box_mask &= (column > min_value) & (column < max_value)
714+
bounding_box_masks.append(bounding_box_mask)
715+
716+
if not (len_df := len(bounding_box_masks)) == (len_bb := len(min_coordinate)):
717+
raise ValueError(f"Length of list of masks `{len_df}` is not equal to the number of bounding boxes `{len_bb}`.")
718+
719+
old_transformations = get_transformation(points, get_all=True)
720+
assert isinstance(old_transformations, dict)
721+
feature_key = points.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)
694722

695-
# assert that the number of queried points is correct
696-
assert len(points_in_intrinsic_bounding_box) == len(min_coordinate)
697-
698-
# # we have to reset the index since we have subset
699-
# # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask
700-
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.assign(idx=1)
701-
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.set_index(
702-
# points_in_intrinsic_bounding_box.idx.cumsum() - 1
703-
# )
704-
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.map_partitions(
705-
# lambda df: df.rename(index={"idx": None})
706-
# )
707-
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"])
708-
709-
# transform the element to the query coordinate system
710723
output: list[DaskDataFrame | None] = []
711-
for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate, strict=True):
712-
if p is None:
724+
for mask_np in bounding_box_masks:
725+
bounding_box_indices = np.flatnonzero(mask_np)
726+
if len(bounding_box_indices) == 0:
713727
output.append(None)
714-
else:
715-
points_query_coordinate_system = transform(
716-
p, to_coordinate_system=target_coordinate_system, maintain_positioning=False
728+
continue
729+
730+
# The exact mask is computed in the query coordinate system, but the returned points must stay intrinsic.
731+
queried_points = points_pd.iloc[bounding_box_indices]
732+
output.append(
733+
PointsModel.parse(
734+
dd.from_pandas(queried_points, npartitions=points.npartitions),
735+
transformations=old_transformations.copy(),
736+
feature_key=feature_key,
717737
)
718-
719-
# get a mask for the points in the bounding box
720-
bounding_box_mask = _bounding_box_mask_points(
721-
points=points_query_coordinate_system,
722-
axes=axes,
723-
min_coordinate=min_c, # type: ignore[arg-type]
724-
max_coordinate=max_c, # type: ignore[arg-type]
725-
)
726-
if len(bounding_box_mask) != 1:
727-
raise ValueError(f"Expected a single mask, got {len(bounding_box_mask)} masks. Please report this bug.")
728-
bounding_box_indices = np.where(bounding_box_mask[0])[0]
729-
730-
if len(bounding_box_indices) == 0:
731-
output.append(None)
732-
else:
733-
points_df = p.compute().iloc[bounding_box_indices]
734-
old_transformations = get_transformation(p, get_all=True)
735-
assert isinstance(old_transformations, dict)
736-
feature_key = p.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)
737-
738-
output.append(
739-
PointsModel.parse(
740-
dd.from_pandas(points_df, npartitions=1),
741-
transformations=old_transformations.copy(),
742-
feature_key=feature_key,
743-
)
744-
)
738+
)
745739
if len(output) == 0:
746740
return None
747741
if len(output) == 1:
@@ -791,8 +785,8 @@ def _(
791785
)
792786
for box_corners in intrinsic_bounding_box_corners:
793787
bounding_box_non_axes_aligned = Polygon(box_corners.data)
794-
indices = polygons.geometry.intersects(bounding_box_non_axes_aligned)
795-
queried = polygons[indices]
788+
candidate_idx = polygons.sindex.query(bounding_box_non_axes_aligned, predicate="intersects")
789+
queried = polygons.iloc[candidate_idx]
796790
if len(queried) == 0:
797791
queried_polygon = None
798792
else:
@@ -949,17 +943,22 @@ def _(
949943
assert np.all(element[OLD_INDEX] == buffered.index)
950944
else:
951945
buffered[OLD_INDEX] = buffered.index
952-
indices = buffered.geometry.apply(lambda x: x.intersects(polygon))
953-
if np.sum(indices) == 0:
946+
947+
# Use sindex for fast candidate pre-filtering, then exact intersection check
948+
# only on the (typically small) candidate set — same pattern as bounding_box_query.
949+
candidate_idx = buffered.sindex.query(polygon, predicate="intersects")
950+
if len(candidate_idx) == 0:
951+
del buffered[OLD_INDEX]
954952
return None
955-
queried_shapes = element[indices]
956-
queried_shapes.index = buffered[indices][OLD_INDEX]
953+
954+
queried_shapes = element.iloc[candidate_idx].copy()
955+
queried_shapes.index = buffered.iloc[candidate_idx][OLD_INDEX]
957956
queried_shapes.index.name = None
958957

959958
if clip:
960959
if isinstance(element.geometry.iloc[0], Point):
961-
queried_shapes = buffered[indices]
962-
queried_shapes.index = buffered[indices][OLD_INDEX]
960+
queried_shapes = buffered.iloc[candidate_idx]
961+
queried_shapes.index = buffered.iloc[candidate_idx][OLD_INDEX]
963962
queried_shapes.index.name = None
964963
queried_shapes = queried_shapes.clip(polygon_gdf, keep_geom_type=True)
965964

src/spatialdata/models/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,7 @@ def _(e: GeoDataFrame) -> tuple[str, ...]:
172172

173173

174174
@get_axes_names.register(DaskDataFrame)
175+
@get_axes_names.register(pd.DataFrame)
175176
def _(e: DaskDataFrame) -> tuple[str, ...]:
176177
valid_dims = (X, Y, Z)
177178
dims = tuple([c for c in valid_dims if c in e.columns])

0 commit comments

Comments
 (0)