Skip to content

Commit 27c08fa

Browse files
authored
Merge pull request #496 from will-moore/scaler_preserve_dtype
Ensure that laplacian and gaussian scaling preserves dtype
2 parents fd57701 + fd18517 commit 27c08fa

File tree

2 files changed

+34
-33
lines changed

2 files changed

+34
-33
lines changed

ome_zarr/scale.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -203,25 +203,25 @@ def _resize(
203203

204204
def gaussian(self, base: np.ndarray) -> list[np.ndarray]:
205205
"""Downsample using :func:`skimage.transform.pyramid_gaussian`."""
206-
return list(
207-
pyramid_gaussian(
208-
base,
209-
downscale=self.downscale,
210-
max_layer=self.max_layer,
211-
channel_axis=None,
212-
)
206+
dtype = base.dtype
207+
pyramid = pyramid_gaussian(
208+
base,
209+
downscale=self.downscale,
210+
max_layer=self.max_layer,
211+
channel_axis=None,
213212
)
213+
return [level.astype(dtype) for level in pyramid]
214214

215215
def laplacian(self, base: np.ndarray) -> list[np.ndarray]:
216216
"""Downsample using :func:`skimage.transform.pyramid_laplacian`."""
217-
return list(
218-
pyramid_laplacian(
219-
base,
220-
downscale=self.downscale,
221-
max_layer=self.max_layer,
222-
channel_axis=None,
223-
)
217+
dtype = base.dtype
218+
pyramid = pyramid_laplacian(
219+
base,
220+
downscale=self.downscale,
221+
max_layer=self.max_layer,
222+
channel_axis=None,
224223
)
224+
return [level.astype(dtype) for level in pyramid]
225225

226226
def local_mean(self, base: np.ndarray) -> list[np.ndarray]:
227227
"""Downsample using :func:`skimage.transform.downscale_local_mean`."""

tests/test_scaler.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,21 @@ def create_data(self, shape, dtype=np.uint8, mean_val=10):
2525
rng = np.random.default_rng(0)
2626
return rng.poisson(mean_val, size=shape).astype(dtype)
2727

28-
def check_downscaled(self, downscaled, shape, scale_factor=2):
29-
expected_shape = shape
30-
for data in downscaled:
31-
assert data.shape == expected_shape
32-
assert data.dtype == downscaled[0].dtype
33-
expected_shape = expected_shape[:-2] + tuple(
34-
sh // scale_factor for sh in expected_shape[-2:]
35-
)
28+
def check_downscaled(self, downscaled, data, scale_factor=2):
29+
expected_shape = data.shape
30+
for level in downscaled:
31+
assert level.dtype == data.dtype
32+
if scale_factor is not None:
33+
assert level.shape == expected_shape
34+
expected_shape = expected_shape[:-2] + tuple(
35+
sh // scale_factor for sh in expected_shape[-2:]
36+
)
3637

3738
def test_nearest(self, shape):
3839
data = self.create_data(shape)
3940
scaler = Scaler()
4041
downscaled = scaler.nearest(data)
41-
self.check_downscaled(downscaled, shape)
42+
self.check_downscaled(downscaled, data)
4243

4344
def test_nearest_via_method(self, shape):
4445
data = self.create_data(shape)
@@ -48,7 +49,7 @@ def test_nearest_via_method(self, shape):
4849

4950
scaler.method = "nearest"
5051
downscaled = scaler.func(data)
51-
self.check_downscaled(downscaled, shape)
52+
self.check_downscaled(downscaled, data)
5253

5354
assert (
5455
np.sum(
@@ -60,27 +61,27 @@ def test_nearest_via_method(self, shape):
6061
== 0
6162
)
6263

63-
# this fails because of wrong channel dimension; need to fix in follow-up PR
64-
@pytest.mark.xfail
64+
# NB: gaussian downscales ALL dimensions, not just YX
65+
# so we SKIP the check on shape
6566
def test_gaussian(self, shape):
6667
data = self.create_data(shape)
6768
scaler = Scaler()
6869
downscaled = scaler.gaussian(data)
69-
self.check_downscaled(downscaled, shape)
70+
self.check_downscaled(downscaled, data, scale_factor=None)
7071

71-
# this fails because of wrong channel dimension; need to fix in follow-up PR
72-
@pytest.mark.xfail
72+
# NB: laplacian downscales ALL dimensions, not just YX
73+
# so we SKIP the check on shape
7374
def test_laplacian(self, shape):
7475
data = self.create_data(shape)
7576
scaler = Scaler()
7677
downscaled = scaler.laplacian(data)
77-
self.check_downscaled(downscaled, shape)
78+
self.check_downscaled(downscaled, data, scale_factor=None)
7879

7980
def test_local_mean(self, shape):
8081
data = self.create_data(shape)
8182
scaler = Scaler()
8283
downscaled = scaler.local_mean(data)
83-
self.check_downscaled(downscaled, shape)
84+
self.check_downscaled(downscaled, data)
8485

8586
def test_local_mean_via_method(self, shape):
8687
data = self.create_data(shape)
@@ -90,7 +91,7 @@ def test_local_mean_via_method(self, shape):
9091

9192
scaler.method = "local_mean"
9293
downscaled = scaler.func(data)
93-
self.check_downscaled(downscaled, shape)
94+
self.check_downscaled(downscaled, data)
9495

9596
assert (
9697
np.sum(
@@ -107,7 +108,7 @@ def test_zoom(self, shape):
107108
data = self.create_data(shape)
108109
scaler = Scaler()
109110
downscaled = scaler.zoom(data)
110-
self.check_downscaled(downscaled, shape)
111+
self.check_downscaled(downscaled, data)
111112

112113
def test_scale_dask(self, shape):
113114
data = self.create_data(shape)

0 commit comments

Comments
 (0)