Skip to content

Commit 39902d2

Browse files
committed
Make transforms to optional
1 parent 02be003 commit 39902d2

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

src/spatialdata/datasets.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def blobs(
135135

136136
def _image_blobs(
137137
self,
138-
transforms: dict[str, Any],
138+
transforms: Optional[dict[str, Any]] = None,
139139
length: int = 512,
140140
multiscale: bool = False,
141141
) -> Union[SpatialImage, MultiscaleSpatialImage]:
@@ -152,7 +152,7 @@ def _image_blobs(
152152
return Image2DModel.parse(x, transformations=transforms, dims=dims, scale_factors=[2, 2])
153153

154154
def _labels_blobs(
155-
self, transforms: dict[str, Any], length: int = 512, multiscale: bool = False
155+
self, transforms: Optional[dict[str, Any]] = None, length: int = 512, multiscale: bool = False
156156
) -> Union[SpatialImage, MultiscaleSpatialImage]:
157157
"""Create a 2D labels."""
158158
from scipy.ndimage import watershed_ift
@@ -193,7 +193,9 @@ def _generate_blobs(self, length: int = 512, seed: Optional[int] = None) -> Arra
193193
mask = gaussian_filter(mask, sigma=0.25 * length * 0.1)
194194
return mask
195195

196-
def _points_blobs(self, transforms: dict[str, Any], length: int = 512, n_points: int = 200) -> DaskDataFrame:
196+
def _points_blobs(
197+
self, transforms: Optional[dict[str, Any]] = None, length: int = 512, n_points: int = 200
198+
) -> DaskDataFrame:
197199
rng = default_rng(42)
198200
arr = rng.integers(10, length - 10, size=(n_points, 2)).astype(np.int64)
199201
# randomly assign some values from v to the points
@@ -209,7 +211,9 @@ def _points_blobs(self, transforms: dict[str, Any], length: int = 512, n_points:
209211
arr, transformations=transforms, annotation=annotation, feature_key="genes", instance_key="instance_id"
210212
)
211213

212-
def _circles_blobs(self, transforms: dict[str, Any], length: int = 512, n_shapes: int = 5) -> GeoDataFrame:
214+
def _circles_blobs(
215+
self, transforms: Optional[dict[str, Any]] = None, length: int = 512, n_shapes: int = 5
216+
) -> GeoDataFrame:
213217
midpoint = length // 2
214218
halfmidpoint = midpoint // 2
215219
radius = length // 10
@@ -223,7 +227,7 @@ def _circles_blobs(self, transforms: dict[str, Any], length: int = 512, n_shapes
223227

224228
def _polygons_blobs(
225229
self,
226-
transforms: dict[str, Any],
230+
transforms: Optional[dict[str, Any]] = None,
227231
length: int = 512,
228232
n_shapes: int = 5,
229233
multipolygons: bool = False,

0 commit comments

Comments
 (0)