Skip to content

Commit 2e55e83

Browse files
Merge pull request #282 from melonora/fix_blobs
Fix blobs
2 parents fed4de4 + 7ba4a23 commit 2e55e83

2 files changed

Lines changed: 63 additions & 27 deletions

File tree

src/spatialdata/datasets.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""SpatialData datasets."""
2-
from typing import Optional, Union
2+
from typing import Any, Optional, Union
33

44
import numpy as np
55
import pandas as pd
@@ -22,11 +22,14 @@
2222
ShapesModel,
2323
TableModel,
2424
)
25+
from spatialdata.transformations import Identity
2526

2627
__all__ = ["blobs", "raccoon"]
2728

2829

29-
def blobs(length: int = 512, n_points: int = 200, n_shapes: int = 5) -> SpatialData:
30+
def blobs(
31+
length: int = 512, n_points: int = 200, n_shapes: int = 5, extra_coord_system: Optional[str] = None
32+
) -> SpatialData:
3033
"""
3134
Blobs dataset.
3235
@@ -39,13 +42,18 @@ def blobs(length: int = 512, n_points: int = 200, n_shapes: int = 5) -> SpatialD
3942
n_shapes
4043
Number of max shapes to generate.
4144
At most, as if overlapping they will be discarded
45+
extra_coord_system
46+
Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
47+
4248
4349
Returns
4450
-------
4551
SpatialData
4652
SpatialData object with blobs dataset.
4753
"""
48-
return BlobsDataset(length=length, n_points=n_points, n_shapes=n_shapes).blobs()
54+
return BlobsDataset(
55+
length=length, n_points=n_points, n_shapes=n_shapes, extra_coord_system=extra_coord_system
56+
).blobs()
4957

5058

5159
def raccoon() -> SpatialData:
@@ -75,7 +83,9 @@ def raccoon(
7583
class BlobsDataset:
7684
"""Blobs dataset."""
7785

78-
def __init__(self, length: int = 512, n_points: int = 200, n_shapes: int = 5) -> None:
86+
def __init__(
87+
self, length: int = 512, n_points: int = 200, n_shapes: int = 5, extra_coord_system: Optional[str] = None
88+
) -> None:
7989
"""
8090
Blobs dataset.
8191
@@ -88,23 +98,28 @@ def __init__(self, length: int = 512, n_points: int = 200, n_shapes: int = 5) ->
8898
n_shapes
8999
Number of max shapes to generate.
90100
At most, as if overlapping they will be discarded
101+
extra_coord_system
102+
Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
91103
"""
92104
self.length = length
93105
self.n_points = n_points
94106
self.n_shapes = n_shapes
107+
self.transformations = {"global": Identity()}
108+
if extra_coord_system:
109+
self.transformations[extra_coord_system] = Identity()
95110

96111
def blobs(
97112
self,
98113
) -> SpatialData:
99114
"""Blobs dataset."""
100-
image = self._image_blobs(self.length)
101-
multiscale_image = self._image_blobs(self.length, multiscale=True)
102-
labels = self._labels_blobs(self.length)
103-
multiscale_labels = self._labels_blobs(self.length, multiscale=True)
104-
points = self._points_blobs(self.length, self.n_points)
105-
circles = self._circles_blobs(self.length, self.n_shapes)
106-
polygons = self._polygons_blobs(self.length, self.n_shapes)
107-
multipolygons = self._polygons_blobs(self.length, self.n_shapes, multipolygons=True)
115+
image = self._image_blobs(self.transformations, self.length)
116+
multiscale_image = self._image_blobs(self.transformations, self.length, multiscale=True)
117+
labels = self._labels_blobs(self.transformations, self.length)
118+
multiscale_labels = self._labels_blobs(self.transformations, self.length, multiscale=True)
119+
points = self._points_blobs(self.transformations, self.length, self.n_points)
120+
circles = self._circles_blobs(self.transformations, self.length, self.n_shapes)
121+
polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes)
122+
multipolygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes, multipolygons=True)
108123
adata = aggregate(image, labels)
109124
adata.obs["region"] = pd.Categorical(["blobs_labels"] * len(adata))
110125
adata.obs["instance_id"] = adata.obs_names.astype(int)
@@ -118,7 +133,12 @@ def blobs(
118133
table=table,
119134
)
120135

121-
def _image_blobs(self, length: int = 512, multiscale: bool = False) -> Union[SpatialImage, MultiscaleSpatialImage]:
136+
def _image_blobs(
137+
self,
138+
transformations: Optional[dict[str, Any]] = None,
139+
length: int = 512,
140+
multiscale: bool = False,
141+
) -> Union[SpatialImage, MultiscaleSpatialImage]:
122142
masks = []
123143
for i in range(3):
124144
mask = self._generate_blobs(length=length, seed=i)
@@ -128,10 +148,12 @@ def _image_blobs(self, length: int = 512, multiscale: bool = False) -> Union[Spa
128148
x = np.stack(masks, axis=0)
129149
dims = ["c", "y", "x"]
130150
if not multiscale:
131-
return Image2DModel.parse(x, dims=dims)
132-
return Image2DModel.parse(x, dims=dims, scale_factors=[2, 2])
151+
return Image2DModel.parse(x, transformations=transformations, dims=dims)
152+
return Image2DModel.parse(x, transformations=transformations, dims=dims, scale_factors=[2, 2])
133153

134-
def _labels_blobs(self, length: int = 512, multiscale: bool = False) -> Union[SpatialImage, MultiscaleSpatialImage]:
154+
def _labels_blobs(
155+
self, transformations: Optional[dict[str, Any]] = None, length: int = 512, multiscale: bool = False
156+
) -> Union[SpatialImage, MultiscaleSpatialImage]:
135157
"""Create a 2D labels."""
136158
from scipy.ndimage import watershed_ift
137159

@@ -155,8 +177,8 @@ def _labels_blobs(self, length: int = 512, multiscale: bool = False) -> Union[Sp
155177
out[out == val[idx]] = i
156178
dims = ["y", "x"]
157179
if not multiscale:
158-
return Labels2DModel.parse(out, dims=dims)
159-
return Labels2DModel.parse(out, dims=dims, scale_factors=[2, 2])
180+
return Labels2DModel.parse(out, transformations=transformations, dims=dims)
181+
return Labels2DModel.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2])
160182

161183
def _generate_blobs(self, length: int = 512, seed: Optional[int] = None) -> ArrayLike:
162184
from scipy.ndimage import gaussian_filter
@@ -171,21 +193,27 @@ def _generate_blobs(self, length: int = 512, seed: Optional[int] = None) -> Arra
171193
mask = gaussian_filter(mask, sigma=0.25 * length * 0.1)
172194
return mask
173195

174-
def _points_blobs(self, length: int = 512, n_points: int = 200) -> DaskDataFrame:
196+
def _points_blobs(
197+
self, transformations: Optional[dict[str, Any]] = None, length: int = 512, n_points: int = 200
198+
) -> DaskDataFrame:
175199
rng = default_rng(42)
176-
arr = rng.integers(10, length - 10, size=(n_points, 2)).astype(np.int_)
200+
arr = rng.integers(10, length - 10, size=(n_points, 2)).astype(np.int64)
177201
# randomly assign some values from v to the points
178-
points_assignment0 = rng.integers(0, 10, size=arr.shape[0]).astype(np.int_)
202+
points_assignment0 = rng.integers(0, 10, size=arr.shape[0]).astype(np.int64)
179203
genes = rng.choice(["a", "b"], size=arr.shape[0])
180204
annotation = pd.DataFrame(
181205
{
182206
"genes": genes,
183207
"instance_id": points_assignment0,
184208
},
185209
)
186-
return PointsModel.parse(arr, annotation=annotation, feature_key="genes", instance_key="instance_id")
210+
return PointsModel.parse(
211+
arr, transformations=transformations, annotation=annotation, feature_key="genes", instance_key="instance_id"
212+
)
187213

188-
def _circles_blobs(self, length: int = 512, n_shapes: int = 5) -> GeoDataFrame:
214+
def _circles_blobs(
215+
self, transformations: Optional[dict[str, Any]] = None, length: int = 512, n_shapes: int = 5
216+
) -> GeoDataFrame:
189217
midpoint = length // 2
190218
halfmidpoint = midpoint // 2
191219
radius = length // 10
@@ -195,9 +223,15 @@ def _circles_blobs(self, length: int = 512, n_shapes: int = 5) -> GeoDataFrame:
195223
"radius": radius,
196224
}
197225
)
198-
return ShapesModel.parse(circles)
226+
return ShapesModel.parse(circles, transformations=transformations)
199227

200-
def _polygons_blobs(self, length: int = 512, n_shapes: int = 5, multipolygons: bool = False) -> GeoDataFrame:
228+
def _polygons_blobs(
229+
self,
230+
transformations: Optional[dict[str, Any]] = None,
231+
length: int = 512,
232+
n_shapes: int = 5,
233+
multipolygons: bool = False,
234+
) -> GeoDataFrame:
201235
midpoint = length // 2
202236
halfmidpoint = midpoint // 2
203237
poly = GeoDataFrame(
@@ -207,7 +241,7 @@ def _polygons_blobs(self, length: int = 512, n_shapes: int = 5, multipolygons: b
207241
)
208242
}
209243
)
210-
return ShapesModel.parse(poly)
244+
return ShapesModel.parse(poly, transformations=transformations)
211245

212246
# function that generates random shapely polygons given a bounding box
213247
def _generate_random_polygons(

tests/datasets/test_datasets.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33

44
def test_datasets() -> None:
5-
sdata_blobs = blobs()
5+
extra_cs = "test"
6+
sdata_blobs = blobs(extra_coord_system=extra_cs)
67

78
assert len(sdata_blobs.table) == 26
89
assert len(sdata_blobs.shapes["blobs_circles"]) == 5
@@ -13,6 +14,7 @@ def test_datasets() -> None:
1314
assert len(sdata_blobs.images["blobs_multiscale_image"]) == 3
1415
assert sdata_blobs.labels["blobs_labels"].shape == (512, 512)
1516
assert len(sdata_blobs.labels["blobs_multiscale_labels"]) == 3
17+
assert extra_cs in sdata_blobs.coordinate_systems
1618
# this catches this bug: https://github.com/scverse/spatialdata/issues/269
1719
_ = str(sdata_blobs)
1820

0 commit comments

Comments
 (0)