|
4 | 4 | os.environ["USE_PYGEOS"] = "0" |
5 | 5 | # isort:on |
6 | 6 |
|
| 7 | +from shapely import linearrings, polygons |
7 | 8 | from pathlib import Path |
8 | 9 | from typing import Union |
9 | 10 | from spatialdata._types import ArrayLike |
|
29 | 30 | ) |
30 | 31 | from xarray import DataArray |
31 | 32 | from spatialdata.datasets import BlobsDataset |
| 33 | +import geopandas as gpd |
32 | 34 |
|
33 | 35 | RNG = default_rng() |
34 | 36 |
|
@@ -249,10 +251,16 @@ def _get_points() -> dict[str, DaskDataFrame]: |
249 | 251 | out = {} |
250 | 252 | for i in range(2): |
251 | 253 | name = f"{name}_{i}" |
252 | | - arr = RNG.normal(size=(100, 2)) |
| 254 | + arr = RNG.normal(size=(300, 2)) |
253 | 255 | # randomly assign some values from v to the points |
254 | 256 | points_assignment0 = RNG.integers(0, 10, size=arr.shape[0]).astype(np.int_) |
255 | | - genes = RNG.choice(["a", "b"], size=arr.shape[0]) |
| 257 | + if i == 0: |
| 258 | + genes = RNG.choice(["a", "b"], size=arr.shape[0]) |
| 259 | + else: |
| 260 | + # we need to test the case in which we have a categorical column with more than 127 categories, see full |
| 261 | + # explanation in write_points() (the parser will convert this column to a categorical since |
| 262 | + # feature_key="genes") |
| 263 | + genes = np.tile(np.array(list(map(str, range(280)))), 2)[:300] |
256 | 264 | annotation = pd.DataFrame( |
257 | 265 | { |
258 | 266 | "genes": genes, |
@@ -299,3 +307,114 @@ def sdata_blobs() -> SpatialData: |
299 | 307 | sdata.labels["blobs_multiscale_labels"] |
300 | 308 | ) |
301 | 309 | return sdata |
| 310 | + |
| 311 | + |
def _make_points(coordinates: np.ndarray) -> DaskDataFrame:
    """
    Build a Points element from raw coordinates, annotated with a ``genes`` column.

    The first third of the points (rounded down) is labelled ``"a"`` and all the remaining points ``"b"``. The
    ``genes`` column is passed to the parser as ``feature_key``.
    """
    n_points = len(coordinates)
    n_a = n_points // 3
    n_b = n_points - n_a
    # one label per point: n_a times "a" followed by n_b times "b"
    gene_labels = np.repeat(["a", "b"], [n_a, n_b])
    annotation = pd.DataFrame({"genes": gene_labels})
    return PointsModel.parse(coordinates, annotation=annotation, feature_key="genes")
| 318 | + |
| 319 | + |
def _make_squares(centroid_coordinates: np.ndarray, half_widths: list[float]) -> GeoDataFrame:
    """
    Build a Shapes element made of axis-aligned squares.

    Parameters
    ----------
    centroid_coordinates
        One centroid per row; only the first two entries of each row (x, y) are used.
    half_widths
        Half of the side length of each square, one value per centroid.

    Returns
    -------
    The result of ``ShapesModel.parse`` applied to a GeoDataFrame with one square polygon per centroid.

    Raises
    ------
    ValueError
        If the number of half widths differs from the number of centroids.
    """
    # zip() would silently drop the unmatched entries, producing fewer squares than requested
    if len(centroid_coordinates) != len(half_widths):
        raise ValueError(
            f"Expected one half width per centroid, got {len(centroid_coordinates)} centroids and "
            f"{len(half_widths)} half widths."
        )

    rings = []
    for centroid, half_width in zip(centroid_coordinates, half_widths):
        lower = centroid - half_width
        upper = centroid + half_width
        # the four corners are listed without repeating the first one; shapely's linearrings() is documented to
        # close the ring by appending the first coordinate
        corners = [
            [lower[0], lower[1]],
            [lower[0], upper[1]],
            [upper[0], upper[1]],
            [upper[0], lower[1]],
        ]
        rings.append(linearrings(corners))

    squares = polygons(rings)
    geometry = gpd.GeoSeries(squares)
    squares_table = gpd.GeoDataFrame(geometry=geometry)
    return ShapesModel.parse(squares_table)
| 340 | + |
| 341 | + |
def _make_circles(centroid_coordinates: np.ndarray, radius: list[float]) -> GeoDataFrame:
    """
    Build a Shapes element from centroid coordinates and one radius per centroid.

    The coordinates are handed to ``ShapesModel.parse`` together with ``geometry=0`` and the given ``radius``.
    """
    # NOTE(review): geometry=0 presumably selects the circle/point representation in the parser -- confirm
    circles = ShapesModel.parse(centroid_coordinates, geometry=0, radius=radius)
    return circles
| 344 | + |
| 345 | + |
def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData:
    """
    Build a SpatialData object covering many edge cases for testing querying and aggregation.

    Returns
    -------
    The SpatialData object, with one points element, four shapes elements and a table annotating the two "values"
    shapes elements.

    Notes
    -----
    What the object contains (for a quick visualization, plot the returned SpatialData object):
        - values to query/aggregate: polygons, points, circles
        - values to query by: polygons, circles
        - shapes that are completely inside, completely outside, or intersecting the query region (with the centroid
          either inside or outside the query region)

    Additional cases:
        - a concave shape intersecting the same shape multiple times; used both as query and as value
        - a shape intersecting multiple shapes; used both as query and as value
    """
    from shapely.geometry import Polygon

    # x positions shared by the row of "values" squares (y=0) and the row of "values" circles (y=30)
    row_x = [x * 18 for x in range(8)] + [8 * 18 + 7]
    values_centroids_squares = np.array([[x, 0] for x in row_x] + [[0, 90], [50, 90]])
    values_centroids_circles = np.array([[x, 30] for x in row_x])
    by_centroids_squares = np.array([[119, 15], [100, 90], [150, 90], [210, 15]])
    by_centroids_circles = np.array([[24, 15], [290, 15]])

    # the points coincide with the centroids of all the "values" shapes
    all_values_centroids = np.vstack((values_centroids_squares, values_centroids_circles))
    values_points = _make_points(all_values_centroids)
    values_squares = _make_squares(values_centroids_squares, half_widths=[6] * 9 + [15, 15])
    values_circles = _make_circles(values_centroids_circles, radius=[6] * 9)
    by_squares = _make_squares(by_centroids_squares, half_widths=[30, 15, 15, 30])
    by_circles = _make_circles(by_centroids_circles, radius=[30, 30])

    # append one concave polygon to the "values" polygons ...
    concave_value = Polygon([(100, 90 - 10), (100 + 30, 90), (100, 90 + 10), (150, 90)])
    values_squares.loc[len(values_squares)] = [concave_value]
    ShapesModel.validate(values_squares)

    # ... and one to the "by" polygons
    concave_by = Polygon([(0, 90 - 10), (0 + 30, 90), (0, 90 + 10), (50, 90)])
    by_squares.loc[len(by_squares)] = [concave_by]
    ShapesModel.validate(by_squares)

    shapes = {
        "values_polygons": values_squares,
        "values_circles": values_circles,
        "by_polygons": by_squares,
        "by_circles": by_circles,
    }
    sdata = SpatialData(points={"points": values_points}, shapes=shapes)

    # to visualize the cases considered in the test, much more immediate than reading them as text as done above
    PLOT = False
    if PLOT:
        import matplotlib.pyplot as plt

        ax = plt.gca()
        values_color = (0.5, 0.2, 0.5, 0.5)
        by_color = (1.0, 0.7, 0.7, 0.5)
        sdata.pl.render_shapes(element="values_polygons", na_color=values_color).pl.render_points().pl.show(ax=ax)
        for element, color in (
            ("values_circles", values_color),
            ("by_polygons", by_color),
            ("by_circles", by_color),
        ):
            sdata.pl.render_shapes(element=element, na_color=color).pl.show(ax=ax)
        plt.show()

    # generate the table: 9 rows for the circles followed by 12 rows for the polygons (11 squares + 1 concave shape)
    x = np.ones((21, 2)) * np.array([1, 2])
    region = np.repeat(["values_circles", "values_polygons"], [9, 12])
    instance_id = np.concatenate((np.arange(9), np.arange(12)))
    obs = pd.DataFrame({"region": region, "instance_id": instance_id})
    table = TableModel.parse(
        AnnData(x, obs=obs),
        region=["values_circles", "values_polygons"],
        region_key="region",
        instance_key="instance_id",
    )
    sdata.table = table
    return sdata
0 commit comments