@@ -25,20 +25,21 @@ def create_data(self, shape, dtype=np.uint8, mean_val=10):
2525 rng = np .random .default_rng (0 )
2626 return rng .poisson (mean_val , size = shape ).astype (dtype )
2727
28- def check_downscaled (self , downscaled , shape , scale_factor = 2 ):
29- expected_shape = shape
30- for data in downscaled :
31- assert data .shape == expected_shape
32- assert data .dtype == downscaled [0 ].dtype
33- expected_shape = expected_shape [:- 2 ] + tuple (
34- sh // scale_factor for sh in expected_shape [- 2 :]
35- )
28+ def check_downscaled (self , downscaled , data , scale_factor = 2 ):
29+ expected_shape = data .shape
30+ for level in downscaled :
31+ assert level .dtype == data .dtype
32+ if scale_factor is not None :
33+ assert level .shape == expected_shape
34+ expected_shape = expected_shape [:- 2 ] + tuple (
35+ sh // scale_factor for sh in expected_shape [- 2 :]
36+ )
3637
3738 def test_nearest (self , shape ):
3839 data = self .create_data (shape )
3940 scaler = Scaler ()
4041 downscaled = scaler .nearest (data )
41- self .check_downscaled (downscaled , shape )
42+ self .check_downscaled (downscaled , data )
4243
4344 def test_nearest_via_method (self , shape ):
4445 data = self .create_data (shape )
@@ -48,7 +49,7 @@ def test_nearest_via_method(self, shape):
4849
4950 scaler .method = "nearest"
5051 downscaled = scaler .func (data )
51- self .check_downscaled (downscaled , shape )
52+ self .check_downscaled (downscaled , data )
5253
5354 assert (
5455 np .sum (
@@ -60,27 +61,27 @@ def test_nearest_via_method(self, shape):
6061 == 0
6162 )
6263
63- # this fails because of wrong channel dimension; need to fix in follow-up PR
64- @ pytest . mark . xfail
64+ # NB: gaussian downscales ALL dimensions, not just YX
65+ # so we SKIP the check on shape
6566 def test_gaussian (self , shape ):
6667 data = self .create_data (shape )
6768 scaler = Scaler ()
6869 downscaled = scaler .gaussian (data )
69- self .check_downscaled (downscaled , shape )
70+ self .check_downscaled (downscaled , data , scale_factor = None )
7071
71- # this fails because of wrong channel dimension; need to fix in follow-up PR
72- @ pytest . mark . xfail
72+ # NB: laplacian downscales ALL dimensions, not just YX
73+ # so we SKIP the check on shape
7374 def test_laplacian (self , shape ):
7475 data = self .create_data (shape )
7576 scaler = Scaler ()
7677 downscaled = scaler .laplacian (data )
77- self .check_downscaled (downscaled , shape )
78+ self .check_downscaled (downscaled , data , scale_factor = None )
7879
7980 def test_local_mean (self , shape ):
8081 data = self .create_data (shape )
8182 scaler = Scaler ()
8283 downscaled = scaler .local_mean (data )
83- self .check_downscaled (downscaled , shape )
84+ self .check_downscaled (downscaled , data )
8485
8586 def test_local_mean_via_method (self , shape ):
8687 data = self .create_data (shape )
@@ -90,7 +91,7 @@ def test_local_mean_via_method(self, shape):
9091
9192 scaler .method = "local_mean"
9293 downscaled = scaler .func (data )
93- self .check_downscaled (downscaled , shape )
94+ self .check_downscaled (downscaled , data )
9495
9596 assert (
9697 np .sum (
@@ -107,7 +108,7 @@ def test_zoom(self, shape):
107108 data = self .create_data (shape )
108109 scaler = Scaler ()
109110 downscaled = scaler .zoom (data )
110- self .check_downscaled (downscaled , shape )
111+ self .check_downscaled (downscaled , data )
111112
112113 def test_scale_dask (self , shape ):
113114 data = self .create_data (shape )
0 commit comments