Skip to content

Commit 93615b2

Browse files
Preserve points feature_key in queries (#794)
* Preserve points feature_key during queries * add PR number to changelog * fix docs; add sphinx-autobuild dep --------- Co-authored-by: Luca Marconato <[email protected]>
1 parent 94f0a31 commit 93615b2

File tree

6 files changed

+33
-7
lines changed

6 files changed

+33
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning][].
2525
### Fixed
2626

2727
- Updated deprecated default stages of `pre-commit` #771
28+
- Preserve points `feature_key` during queries #794
2829

2930
## [0.2.5] - 2024-06-11
3031

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ dev = [
5555
]
5656
docs = [
5757
"sphinx>=4.5",
58+
"sphinx-autobuild",
5859
"sphinx-book-theme>=1.0.0",
5960
"myst-nb",
6061
"sphinxcontrib-bibtex>=1.0.0",

src/spatialdata/_core/query/spatial_query.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
points_geopandas_to_dask_dataframe,
3434
)
3535
from spatialdata.models._utils import ValidAxis_t, get_spatial_axes
36+
from spatialdata.models.models import ATTRS_KEY
3637
from spatialdata.transformations.operations import set_transformation
3738
from spatialdata.transformations.transformations import (
3839
Affine,
@@ -712,9 +713,13 @@ def _(
712713
points_df = p.compute().iloc[bounding_box_indices]
713714
old_transformations = get_transformation(p, get_all=True)
714715
assert isinstance(old_transformations, dict)
716+
feature_key = p.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)
717+
715718
output.append(
716719
PointsModel.parse(
717-
dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy()
720+
dd.from_pandas(points_df, npartitions=1),
721+
transformations=old_transformations.copy(),
722+
feature_key=feature_key,
718723
)
719724
)
720725
if len(output) == 0:
@@ -925,10 +930,11 @@ def _(
925930
queried_points = points_gdf.loc[joined["index_right"]]
926931
ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True)
927932
transformation = get_transformation(points, target_coordinate_system)
933+
feature_key = points.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)
928934
if "z" in ddf.columns:
929-
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
935+
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"}, feature_key=feature_key)
930936
else:
931-
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"})
937+
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"}, feature_key=feature_key)
932938
set_transformation(ddf, transformation, target_coordinate_system)
933939
t = get_transformation(ddf, get_all=True)
934940
assert isinstance(t, dict)

src/spatialdata/models/_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,7 @@ def set_channel_names(element: DataArray | DataTree, channel_names: str | list[s
417417
418418
Returns
419419
-------
420-
element
421-
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
420+
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
422421
"""
423422
from spatialdata.models import Image2DModel, Image3DModel, get_model
424423

src/spatialdata/models/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,11 @@ def _(
736736
elif isinstance(data, dd.DataFrame): # type: ignore[attr-defined]
737737
table = data[[coordinates[ax] for ax in axes]]
738738
table.columns = axes
739-
if feature_key is not None and data[feature_key].dtype.name != "category":
740-
table[feature_key] = data[feature_key].astype(str).astype("category")
739+
if feature_key is not None:
740+
if data[feature_key].dtype.name == "category":
741+
table[feature_key] = data[feature_key]
742+
else:
743+
table[feature_key] = data[feature_key].astype(str).astype("category")
741744
if instance_key is not None:
742745
table[instance_key] = data[instance_key]
743746
for c in [X, Y, Z]:

tests/core/query/test_spatial_query.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ShapesModel,
3030
TableModel,
3131
)
32+
from spatialdata.models.models import ATTRS_KEY
3233
from spatialdata.testing import assert_spatial_data_objects_are_identical
3334
from spatialdata.transformations import Identity, MapAxis, set_transformation
3435
from tests.conftest import _make_points, _make_squares
@@ -205,6 +206,21 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, mul
205206
if is_3d:
206207
np.testing.assert_allclose(points_element["z"].compute(), original_z)
207208

209+
# the feature_key should be preserved
210+
if not multiple_boxes:
211+
assert (
212+
points_result.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
213+
== points_element.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
214+
)
215+
else:
216+
for result in points_result:
217+
if result is None:
218+
continue
219+
assert (
220+
result.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
221+
== points_element.attrs[ATTRS_KEY][PointsModel.FEATURE_KEY]
222+
)
223+
208224

209225
def test_query_points_no_points():
210226
"""Points bounding box query with no points in range should

0 commit comments

Comments
 (0)