11"""SpatialData datasets."""
2- from typing import Optional , Union
2+ from typing import Any , Optional , Union
33
44import numpy as np
55import pandas as pd
2222 ShapesModel ,
2323 TableModel ,
2424)
25+ from spatialdata .transformations import Identity
2526
2627__all__ = ["blobs" , "raccoon" ]
2728
2829
29- def blobs (length : int = 512 , n_points : int = 200 , n_shapes : int = 5 ) -> SpatialData :
30+ def blobs (
31+ length : int = 512 , n_points : int = 200 , n_shapes : int = 5 , extra_coord_system : Optional [str ] = None
32+ ) -> SpatialData :
3033 """
3134 Blobs dataset.
3235
@@ -39,13 +42,18 @@ def blobs(length: int = 512, n_points: int = 200, n_shapes: int = 5) -> SpatialD
3942 n_shapes
4043 Number of max shapes to generate.
4144 At most, as if overlapping they will be discarded
45+ extra_coord_system
46+ Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
47+
4248
4349 Returns
4450 -------
4551 SpatialData
4652 SpatialData object with blobs dataset.
4753 """
48- return BlobsDataset (length = length , n_points = n_points , n_shapes = n_shapes ).blobs ()
54+ return BlobsDataset (
55+ length = length , n_points = n_points , n_shapes = n_shapes , extra_coord_system = extra_coord_system
56+ ).blobs ()
4957
5058
5159def raccoon () -> SpatialData :
@@ -75,7 +83,9 @@ def raccoon(
7583class BlobsDataset :
7684 """Blobs dataset."""
7785
78- def __init__ (self , length : int = 512 , n_points : int = 200 , n_shapes : int = 5 ) -> None :
86+ def __init__ (
87+ self , length : int = 512 , n_points : int = 200 , n_shapes : int = 5 , extra_coord_system : Optional [str ] = None
88+ ) -> None :
7989 """
8090 Blobs dataset.
8191
@@ -88,23 +98,28 @@ def __init__(self, length: int = 512, n_points: int = 200, n_shapes: int = 5) ->
8898 n_shapes
8999 Number of max shapes to generate.
90100 At most, as if overlapping they will be discarded
101+ extra_coord_system
102+ Extra coordinate space on top of the standard global coordinate space. Will have only identity transform.
91103 """
92104 self .length = length
93105 self .n_points = n_points
94106 self .n_shapes = n_shapes
107+ self .transformations = {"global" : Identity ()}
108+ if extra_coord_system :
109+ self .transformations [extra_coord_system ] = Identity ()
95110
96111 def blobs (
97112 self ,
98113 ) -> SpatialData :
99114 """Blobs dataset."""
100- image = self ._image_blobs (self .length )
101- multiscale_image = self ._image_blobs (self .length , multiscale = True )
102- labels = self ._labels_blobs (self .length )
103- multiscale_labels = self ._labels_blobs (self .length , multiscale = True )
104- points = self ._points_blobs (self .length , self .n_points )
105- circles = self ._circles_blobs (self .length , self .n_shapes )
106- polygons = self ._polygons_blobs (self .length , self .n_shapes )
107- multipolygons = self ._polygons_blobs (self .length , self .n_shapes , multipolygons = True )
115+ image = self ._image_blobs (self .transformations , self . length )
116+ multiscale_image = self ._image_blobs (self .transformations , self . length , multiscale = True )
117+ labels = self ._labels_blobs (self .transformations , self . length )
118+ multiscale_labels = self ._labels_blobs (self .transformations , self . length , multiscale = True )
119+ points = self ._points_blobs (self .transformations , self . length , self .n_points )
120+ circles = self ._circles_blobs (self .transformations , self . length , self .n_shapes )
121+ polygons = self ._polygons_blobs (self .transformations , self . length , self .n_shapes )
122+ multipolygons = self ._polygons_blobs (self .transformations , self . length , self .n_shapes , multipolygons = True )
108123 adata = aggregate (image , labels )
109124 adata .obs ["region" ] = pd .Categorical (["blobs_labels" ] * len (adata ))
110125 adata .obs ["instance_id" ] = adata .obs_names .astype (int )
@@ -118,7 +133,12 @@ def blobs(
118133 table = table ,
119134 )
120135
121- def _image_blobs (self , length : int = 512 , multiscale : bool = False ) -> Union [SpatialImage , MultiscaleSpatialImage ]:
136+ def _image_blobs (
137+ self ,
138+ transformations : Optional [dict [str , Any ]] = None ,
139+ length : int = 512 ,
140+ multiscale : bool = False ,
141+ ) -> Union [SpatialImage , MultiscaleSpatialImage ]:
122142 masks = []
123143 for i in range (3 ):
124144 mask = self ._generate_blobs (length = length , seed = i )
@@ -128,10 +148,12 @@ def _image_blobs(self, length: int = 512, multiscale: bool = False) -> Union[Spa
128148 x = np .stack (masks , axis = 0 )
129149 dims = ["c" , "y" , "x" ]
130150 if not multiscale :
131- return Image2DModel .parse (x , dims = dims )
132- return Image2DModel .parse (x , dims = dims , scale_factors = [2 , 2 ])
151+ return Image2DModel .parse (x , transformations = transformations , dims = dims )
152+ return Image2DModel .parse (x , transformations = transformations , dims = dims , scale_factors = [2 , 2 ])
133153
134- def _labels_blobs (self , length : int = 512 , multiscale : bool = False ) -> Union [SpatialImage , MultiscaleSpatialImage ]:
154+ def _labels_blobs (
155+ self , transformations : Optional [dict [str , Any ]] = None , length : int = 512 , multiscale : bool = False
156+ ) -> Union [SpatialImage , MultiscaleSpatialImage ]:
135157 """Create a 2D labels."""
136158 from scipy .ndimage import watershed_ift
137159
@@ -155,8 +177,8 @@ def _labels_blobs(self, length: int = 512, multiscale: bool = False) -> Union[Sp
155177 out [out == val [idx ]] = i
156178 dims = ["y" , "x" ]
157179 if not multiscale :
158- return Labels2DModel .parse (out , dims = dims )
159- return Labels2DModel .parse (out , dims = dims , scale_factors = [2 , 2 ])
180+ return Labels2DModel .parse (out , transformations = transformations , dims = dims )
181+ return Labels2DModel .parse (out , transformations = transformations , dims = dims , scale_factors = [2 , 2 ])
160182
161183 def _generate_blobs (self , length : int = 512 , seed : Optional [int ] = None ) -> ArrayLike :
162184 from scipy .ndimage import gaussian_filter
@@ -171,21 +193,27 @@ def _generate_blobs(self, length: int = 512, seed: Optional[int] = None) -> Arra
171193 mask = gaussian_filter (mask , sigma = 0.25 * length * 0.1 )
172194 return mask
173195
174- def _points_blobs (self , length : int = 512 , n_points : int = 200 ) -> DaskDataFrame :
196+ def _points_blobs (
197+ self , transformations : Optional [dict [str , Any ]] = None , length : int = 512 , n_points : int = 200
198+ ) -> DaskDataFrame :
175199 rng = default_rng (42 )
176- arr = rng .integers (10 , length - 10 , size = (n_points , 2 )).astype (np .int_ )
200+ arr = rng .integers (10 , length - 10 , size = (n_points , 2 )).astype (np .int64 )
177201 # randomly assign some values from v to the points
178- points_assignment0 = rng .integers (0 , 10 , size = arr .shape [0 ]).astype (np .int_ )
202+ points_assignment0 = rng .integers (0 , 10 , size = arr .shape [0 ]).astype (np .int64 )
179203 genes = rng .choice (["a" , "b" ], size = arr .shape [0 ])
180204 annotation = pd .DataFrame (
181205 {
182206 "genes" : genes ,
183207 "instance_id" : points_assignment0 ,
184208 },
185209 )
186- return PointsModel .parse (arr , annotation = annotation , feature_key = "genes" , instance_key = "instance_id" )
210+ return PointsModel .parse (
211+ arr , transformations = transformations , annotation = annotation , feature_key = "genes" , instance_key = "instance_id"
212+ )
187213
188- def _circles_blobs (self , length : int = 512 , n_shapes : int = 5 ) -> GeoDataFrame :
214+ def _circles_blobs (
215+ self , transformations : Optional [dict [str , Any ]] = None , length : int = 512 , n_shapes : int = 5
216+ ) -> GeoDataFrame :
189217 midpoint = length // 2
190218 halfmidpoint = midpoint // 2
191219 radius = length // 10
@@ -195,9 +223,15 @@ def _circles_blobs(self, length: int = 512, n_shapes: int = 5) -> GeoDataFrame:
195223 "radius" : radius ,
196224 }
197225 )
198- return ShapesModel .parse (circles )
226+ return ShapesModel .parse (circles , transformations = transformations )
199227
200- def _polygons_blobs (self , length : int = 512 , n_shapes : int = 5 , multipolygons : bool = False ) -> GeoDataFrame :
228+ def _polygons_blobs (
229+ self ,
230+ transformations : Optional [dict [str , Any ]] = None ,
231+ length : int = 512 ,
232+ n_shapes : int = 5 ,
233+ multipolygons : bool = False ,
234+ ) -> GeoDataFrame :
201235 midpoint = length // 2
202236 halfmidpoint = midpoint // 2
203237 poly = GeoDataFrame (
@@ -207,7 +241,7 @@ def _polygons_blobs(self, length: int = 512, n_shapes: int = 5, multipolygons: b
207241 )
208242 }
209243 )
210- return ShapesModel .parse (poly )
244+ return ShapesModel .parse (poly , transformations = transformations )
211245
212246 # function that generates random shapely polygons given a bounding box
213247 def _generate_random_polygons (
0 commit comments