Skip to content

Commit a0063cc

Browse files
Merge pull request #271 from scverse/fix/aggregation
Fix/aggregation
2 parents ae61666 + 1ab3361 commit a0063cc

10 files changed

Lines changed: 311 additions & 208 deletions

File tree

src/spatialdata/_core/operations/aggregate.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
ShapesModel,
2525
get_model,
2626
)
27-
from spatialdata.models._utils import get_axes_names
2827
from spatialdata.transformations import BaseTransformation, Identity, get_transformation
2928

3029
__all__ = ["aggregate"]
@@ -112,8 +111,8 @@ def _aggregate_points_by_shapes(
112111
value_key: str | None = None,
113112
agg_func: str | list[str] = "count",
114113
) -> ad.AnnData:
115-
# Have to get dims on dask dataframe, can't get from pandas
116-
dims = get_axes_names(points)
114+
from spatialdata.models import points_dask_dataframe_to_geopandas
115+
117116
# Default value for id_key
118117
if id_key is None:
119118
id_key = points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY]
@@ -123,9 +122,8 @@ def _aggregate_points_by_shapes(
123122
"`FEATURE_KEY` for the points."
124123
)
125124

126-
if isinstance(points, ddf.DataFrame):
127-
points = points.compute()
128-
points = gpd.GeoDataFrame(points, geometry=gpd.points_from_xy(*[points[dim] for dim in dims]))
125+
points = points_dask_dataframe_to_geopandas(points, suppress_z_warning=True)
126+
shapes = circles_to_polygons(shapes)
129127

130128
return _aggregate_shapes(points, shapes, id_key, value_key, agg_func)
131129

@@ -253,8 +251,12 @@ def _aggregate_shapes(
253251
value_key: point_values,
254252
}
255253
)
254+
##
256255
aggregated = to_agg.groupby([by_id_key, id_key]).agg(agg_func).reset_index()
257-
obs_id_categorical = pd.Categorical(aggregated[by_id_key])
256+
257+
# this is for only shapes in "by" that intersect with something in "value"
258+
obs_id_categorical_categories = by.index.tolist()
259+
obs_id_categorical = pd.Categorical(aggregated[by_id_key], categories=obs_id_categorical_categories)
258260

259261
X = sparse.coo_matrix(
260262
(
@@ -265,7 +267,7 @@ def _aggregate_shapes(
265267
).tocsr()
266268
return ad.AnnData(
267269
X,
268-
obs=pd.DataFrame(index=obs_id_categorical.categories),
270+
obs=pd.DataFrame(index=pd.Categorical(obs_id_categorical_categories).categories),
269271
var=pd.DataFrame(index=joined[id_key].cat.categories),
270272
dtype=X.dtype,
271273
)

src/spatialdata/_core/query/spatial_query.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,18 @@ def _bounding_box_mask_points(
247247
def _dict_query_dispatcher(
248248
elements: dict[str, SpatialElement], query_function: Callable[[SpatialElement], SpatialElement], **kwargs: Any
249249
) -> dict[str, SpatialElement]:
250+
from spatialdata.transformations import get_transformation
251+
250252
queried_elements = {}
251253
for key, element in elements.items():
252-
result = query_function(element, **kwargs)
253-
if result is not None:
254-
# query returns None if it is empty
255-
queried_elements[key] = result
254+
target_coordinate_system = kwargs["target_coordinate_system"]
255+
d = get_transformation(element, get_all=True)
256+
assert isinstance(d, dict)
257+
if target_coordinate_system in d:
258+
result = query_function(element, **kwargs)
259+
if result is not None:
260+
# query returns None if it is empty
261+
queried_elements[key] = result
256262
return queried_elements
257263

258264

@@ -649,12 +655,12 @@ def _polygon_query(
649655
new_points = {}
650656
if points:
651657
for points_name, p in sdata.points.items():
652-
points_gdf = points_dask_dataframe_to_geopandas(p)
658+
points_gdf = points_dask_dataframe_to_geopandas(p, suppress_z_warning=True)
653659
indices = points_gdf.geometry.intersects(polygon)
654660
if np.sum(indices) == 0:
655661
raise ValueError("we expect at least one point")
656662
queried_points = points_gdf[indices]
657-
ddf = points_geopandas_to_dask_dataframe(queried_points)
663+
ddf = points_geopandas_to_dask_dataframe(queried_points, suppress_z_warning=True)
658664
transformation = get_transformation(p, target_coordinate_system)
659665
if "z" in ddf.columns:
660666
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})

src/spatialdata/_io/io_points.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ def write_points(
5454

5555
points_groups = group.require_group(name)
5656
path = Path(points_groups._store.path) / points_groups.path / "points.parquet"
57+
58+
# The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is
59+
# 'category', it checks whether the categories of this column are known. If not, it explicitly converts the
60+
# categories to known categories using 'c.cat.as_known()' and assigns the transformed Series back to the original
61+
# DataFrame. This step is crucial when the number of categories exceeds 127, as pyarrow defaults to int8 for
62+
# unknown categories which can only hold values from -128 to 127.
63+
for column_name in points.columns:
64+
c = points[column_name]
65+
if c.dtype == "category" and not c.cat.known:
66+
c = c.cat.as_known()
67+
points[column_name] = c
68+
5769
points.to_parquet(path)
5870

5971
attrs = fmt.attrs_to_dict(points.attrs)

src/spatialdata/models/_utils.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _validate_dims(dims: tuple[str, ...]) -> None:
190190
raise ValueError(f"Invalid dimensions: {dims}")
191191

192192

193-
def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame:
193+
def points_dask_dataframe_to_geopandas(points: DaskDataFrame, suppress_z_warning: bool = False) -> GeoDataFrame:
194194
"""
195195
Convert a Dask DataFrame to a GeoDataFrame.
196196
@@ -212,16 +212,25 @@ def points_dask_dataframe_to_geopandas(points: DaskDataFrame) -> GeoDataFrame:
212212
points need to be saved as a Dask DataFrame. We will be restructuring the models to allow for GeoDataFrames soon.
213213
214214
"""
215-
if "z" in points.columns:
215+
from spatialdata.transformations import get_transformation, set_transformation
216+
217+
if "z" in points.columns and not suppress_z_warning:
216218
logger.warning("Constructing the GeoDataFrame without considering the z coordinate in the geometry.")
217219

218-
points_gdf = GeoDataFrame(geometry=geopandas.points_from_xy(points["x"], points["y"]))
219-
for c in points.columns:
220-
points_gdf[c] = points[c]
220+
transformations = get_transformation(points, get_all=True)
221+
assert isinstance(transformations, dict)
222+
assert len(transformations) > 0
223+
points = points.compute()
224+
points_gdf = GeoDataFrame(points, geometry=geopandas.points_from_xy(points["x"], points["y"]))
225+
points_gdf.reset_index(drop=True, inplace=True)
226+
# keep the x and y either in the geometry either as columns: we don't duplicate because having this redundancy could
227+
# lead to subtle bugs when coverting back to dask dataframes
228+
points_gdf.drop(columns=["x", "y"], inplace=True)
229+
set_transformation(points_gdf, transformations, set_all=True)
221230
return points_gdf
222231

223232

224-
def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame:
233+
def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bool = False) -> DaskDataFrame:
225234
"""
226235
Convert a GeoDataFrame which represents 2D or 3D points to a Dask DataFrame that passes the schema validation.
227236
@@ -241,15 +250,20 @@ def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame) -> DaskDataFrame:
241250
"""
242251
from spatialdata.models import PointsModel
243252

253+
# transformations are transferred automatically
244254
ddf = dd.from_pandas(gdf[gdf.columns.drop("geometry")], npartitions=1)
255+
# we don't want redundancy in the columns since this could lead to subtle bugs when converting back to geopandas
256+
assert "x" not in ddf.columns
257+
assert "y" not in ddf.columns
245258
ddf["x"] = gdf.geometry.x
246259
ddf["y"] = gdf.geometry.y
247260
# parse
248261
if "z" in ddf.columns:
249-
logger.warning(
250-
"Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an "
251-
"additional column."
252-
)
262+
if not suppress_z_warning:
263+
logger.warning(
264+
"Constructing the Dask DataFrame using the x and y coordinates from the geometry and the z from an "
265+
"additional column."
266+
)
253267
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y", "z": "z"})
254268
else:
255269
ddf = PointsModel.parse(ddf, coordinates={"x": "x", "y": "y"})

tests/conftest.py

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
os.environ["USE_PYGEOS"] = "0"
55
# isort:on
66

7+
from shapely import linearrings, polygons
78
from pathlib import Path
89
from typing import Union
910
from spatialdata._types import ArrayLike
@@ -29,6 +30,7 @@
2930
)
3031
from xarray import DataArray
3132
from spatialdata.datasets import BlobsDataset
33+
import geopandas as gpd
3234

3335
RNG = default_rng()
3436

@@ -249,10 +251,16 @@ def _get_points() -> dict[str, DaskDataFrame]:
249251
out = {}
250252
for i in range(2):
251253
name = f"{name}_{i}"
252-
arr = RNG.normal(size=(100, 2))
254+
arr = RNG.normal(size=(300, 2))
253255
# randomly assign some values from v to the points
254256
points_assignment0 = RNG.integers(0, 10, size=arr.shape[0]).astype(np.int_)
255-
genes = RNG.choice(["a", "b"], size=arr.shape[0])
257+
if i == 0:
258+
genes = RNG.choice(["a", "b"], size=arr.shape[0])
259+
else:
260+
# we need to test the case in which we have a categorical column with more than 127 categories, see full
261+
# explanation in write_points() (the parser will convert this column to a categorical since
262+
# feature_key="genes")
263+
genes = np.tile(np.array(list(map(str, range(280)))), 2)[:300]
256264
annotation = pd.DataFrame(
257265
{
258266
"genes": genes,
@@ -299,3 +307,114 @@ def sdata_blobs() -> SpatialData:
299307
sdata.labels["blobs_multiscale_labels"]
300308
)
301309
return sdata
310+
311+
312+
def _make_points(coordinates: np.ndarray) -> DaskDataFrame:
313+
"""Helper function to make a Points element."""
314+
k0 = int(len(coordinates) / 3)
315+
k1 = len(coordinates) - k0
316+
genes = np.hstack((np.repeat("a", k0), np.repeat("b", k1)))
317+
return PointsModel.parse(coordinates, annotation=pd.DataFrame({"genes": genes}), feature_key="genes")
318+
319+
320+
def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> polygons:
321+
linear_rings = []
322+
for centroid, half_width in zip(centroid_coordinates, half_widths):
323+
min_coords = centroid - half_width
324+
max_coords = centroid + half_width
325+
326+
linear_rings.append(
327+
linearrings(
328+
[
329+
[min_coords[0], min_coords[1]],
330+
[min_coords[0], max_coords[1]],
331+
[max_coords[0], max_coords[1]],
332+
[max_coords[0], min_coords[1]],
333+
]
334+
)
335+
)
336+
s = polygons(linear_rings)
337+
polygon_series = gpd.GeoSeries(s)
338+
cell_polygon_table = gpd.GeoDataFrame(geometry=polygon_series)
339+
return ShapesModel.parse(cell_polygon_table)
340+
341+
342+
def _make_circles(centroid_coordinates: np.ndarray, radius: list[float]) -> GeoDataFrame:
343+
return ShapesModel.parse(centroid_coordinates, geometry=0, radius=radius)
344+
345+
346+
def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData:
347+
"""
348+
Creates a SpatialData object with many edge cases for testing querying and aggregation.
349+
350+
Returns
351+
-------
352+
The SpatialData object.
353+
354+
Notes
355+
-----
356+
Description of what is tested (for a quick visualization, plot the returned SpatialData object):
357+
- values to query/aggregate: polygons, points, circles
358+
- values to query by: polygons, circles
359+
- the shapes are completely inside, outside, or intersecting the query region (with the centroid inside or outside
360+
the query region)
361+
362+
Additional cases:
363+
- concave shape intersecting multiple times the same shape; used both as query and as value
364+
- shape intersecting multiple shapes; used both as query and as value
365+
"""
366+
values_centroids_squares = np.array([[x * 18, 0] for x in range(8)] + [[8 * 18 + 7, 0]] + [[0, 90], [50, 90]])
367+
values_centroids_circles = np.array([[x * 18, 30] for x in range(8)] + [[8 * 18 + 7, 30]])
368+
by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90], [210, 15]])
369+
by_centroids_circles = np.array([[24, 15], [290, 15]])
370+
values_points = _make_points(np.vstack((values_centroids_squares, values_centroids_circles)))
371+
values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15])
372+
values_circles = _make_circles(values_centroids_circles, radius=[6] * 9)
373+
by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15, 30])
374+
by_circles = _make_circles(by_centroids_circles, radius=[30, 30])
375+
376+
from shapely.geometry import Polygon
377+
378+
polygon = Polygon([(100, 90 - 10), (100 + 30, 90), (100, 90 + 10), (150, 90)])
379+
values_squares.loc[len(values_squares)] = [polygon]
380+
ShapesModel.validate(values_squares)
381+
382+
polygon = Polygon([(0, 90 - 10), (0 + 30, 90), (0, 90 + 10), (50, 90)])
383+
by_squares.loc[len(by_squares)] = [polygon]
384+
ShapesModel.validate(by_squares)
385+
386+
sdata = SpatialData(
387+
points={"points": values_points},
388+
shapes={
389+
"values_polygons": values_squares,
390+
"values_circles": values_circles,
391+
"by_polygons": by_squares,
392+
"by_circles": by_circles,
393+
},
394+
)
395+
# to visualize the cases considered in the test, much more immediate than reading them as text as done above
396+
PLOT = False
397+
if PLOT:
398+
##
399+
import matplotlib.pyplot as plt
400+
401+
ax = plt.gca()
402+
sdata.pl.render_shapes(element="values_polygons", na_color=(0.5, 0.2, 0.5, 0.5)).pl.render_points().pl.show(
403+
ax=ax
404+
)
405+
sdata.pl.render_shapes(element="values_circles", na_color=(0.5, 0.2, 0.5, 0.5)).pl.show(ax=ax)
406+
sdata.pl.render_shapes(element="by_polygons", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax)
407+
sdata.pl.render_shapes(element="by_circles", na_color=(1.0, 0.7, 0.7, 0.5)).pl.show(ax=ax)
408+
plt.show()
409+
##
410+
411+
# generate table
412+
x = np.ones((21, 2)) * np.array([1, 2])
413+
region = np.array(["values_circles"] * 9 + ["values_polygons"] * 12)
414+
instance_id = np.array(list(range(9)) + list(range(12)))
415+
table = AnnData(x, obs=pd.DataFrame({"region": region, "instance_id": instance_id}))
416+
table = TableModel.parse(
417+
table, region=["values_circles", "values_polygons"], region_key="region", instance_key="instance_id"
418+
)
419+
sdata.table = table
420+
return sdata

0 commit comments

Comments
 (0)