11import pytest
2+ from xarray .core import dataarray
23from xarray_multiscale .multiscale import (
34 downscale ,
45 prepad ,
910import dask .array as da
1011import numpy as np
1112from xarray import DataArray
13+ from xarray .testing import assert_equal
1214
1315
1416def test_downscale_depth ():
@@ -23,11 +25,13 @@ def test_downscale_depth():
2325 assert get_downscale_depth ((7 , 3 , 3 ), (2 , 2 , 2 ), pad = True ) == 2
2426 assert get_downscale_depth ((1500 , 5495 , 5200 ), (2 , 2 , 2 )) == 10
2527
26- @pytest .mark .parametrize (("size" ,"scale" ), ((10 ,2 ), (11 ,2 ), (12 ,2 ), (13 ,2 )))
28+
29+ @pytest .mark .parametrize (("size" , "scale" ), ((10 , 2 ), (11 , 2 ), (12 , 2 ), (13 , 2 )))
2730def test_even_padding (size : int , scale : int ) -> None :
2831 assert (size + even_padding (size , scale )) % scale == 0
2932
30- @pytest .mark .parametrize ('dim' , (1 ,2 ,3 ,4 ))
33+
34+ @pytest .mark .parametrize ("dim" , (1 , 2 , 3 , 4 ))
3135def test_prepad (dim : int ) -> None :
3236 size = (10 ,) * dim
3337 chunks = (9 ,) * dim
@@ -53,17 +57,14 @@ def test_downscale_2d():
5357 arr_dask = da .from_array (arr_numpy , chunks = chunks )
5458 arr_xarray = DataArray (arr_dask )
5559
56- downscaled_numpy_float = downscale (
57- arr_numpy , np .mean , scale ).compute ()
60+ downscaled_numpy_float = downscale (arr_numpy , np .mean , scale ).compute ()
5861
59- downscaled_dask_float = downscale (
60- arr_dask , np .mean , scale ).compute ()
62+ downscaled_dask_float = downscale (arr_dask , np .mean , scale ).compute ()
6163
62- downscaled_xarray_float = downscale (
63- arr_xarray , np .mean , scale ).compute ()
64+ downscaled_xarray_float = downscale (arr_xarray , np .mean , scale ).compute ()
6465
6566 answer_float = np .array ([[0.5 , 0.5 , 0.5 , 0.5 ], [0.5 , 0.5 , 0.5 , 0.5 ]])
66-
67+
6768 assert np .array_equal (downscaled_numpy_float , answer_float )
6869 assert np .array_equal (downscaled_dask_float , answer_float )
6970 assert np .array_equal (downscaled_xarray_float , answer_float )
@@ -77,28 +78,52 @@ def test_multiscale():
7778 cell = np .zeros (np .prod (chunks )).astype ("float" )
7879 cell [0 ] = 1
7980 cell = cell .reshape (* chunks )
80- array = np .tile (cell , np .ceil (np .divide (shape , chunks )).astype ("int" ))[cropslice ]
81-
82- pyr_trimmed = multiscale (array , np .mean , 2 , pad_mode = None )
83- pyr_padded = multiscale (array , np .mean , 2 , pad_mode = "reflect" )
84- pyr_trimmed_unchained = multiscale (array , np .mean , 2 , pad_mode = None , chained = False )
81+ base_array = np .tile (cell , np .ceil (np .divide (shape , chunks )).astype ("int" ))[
82+ cropslice
83+ ]
84+ pyr_trimmed = multiscale (base_array , np .mean , 2 , pad_mode = None )
85+ pyr_padded = multiscale (base_array , np .mean , 2 , pad_mode = "reflect" )
86+ pyr_trimmed_unchained = multiscale (
87+ base_array , np .mean , 2 , pad_mode = None , chained = False
88+ )
8589 assert [p .shape for p in pyr_padded ] == [
8690 shape ,
8791 (5 , 5 , 5 ),
8892 (3 , 3 , 3 ),
8993 (2 , 2 , 2 ),
9094 (1 , 1 , 1 ),
9195 ]
92- assert [p .shape for p in pyr_trimmed ] == [shape , (4 , 4 , 4 ), (2 , 2 , 2 ), (1 ,1 , 1 )]
96+ assert [p .shape for p in pyr_trimmed ] == [shape , (4 , 4 , 4 ), (2 , 2 , 2 ), (1 , 1 , 1 )]
9397
9498 # check that the first multiscale array is identical to the input data
95- assert np .array_equal (pyr_padded [0 ].data .compute (), array )
96- assert np .array_equal (pyr_trimmed [0 ].data .compute (), array )
99+ assert np .array_equal (pyr_padded [0 ].data .compute (), base_array )
100+ assert np .array_equal (pyr_trimmed [0 ].data .compute (), base_array )
97101
98102 assert np .array_equal (
99103 pyr_trimmed [- 2 ].data .mean ().compute (), pyr_trimmed [- 1 ].data .compute ().mean ()
100104 )
101105 assert np .array_equal (
102- pyr_trimmed_unchained [- 2 ].data .mean ().compute (), pyr_trimmed_unchained [- 1 ].data .compute ().mean ()
106+ pyr_trimmed_unchained [- 2 ].data .mean ().compute (),
107+ pyr_trimmed_unchained [- 1 ].data .compute ().mean (),
103108 )
104109 assert np .allclose (pyr_padded [0 ].data .mean ().compute (), 0.17146776406035666 )
110+
111+
112+ def test_coords ():
113+ dims = ("z" , "y" , "x" )
114+ shape = (16 ,) * len (dims )
115+ base_array = np .random .randint (0 , 255 , shape , dtype = "uint8" )
116+
117+ translates = (0.0 , - 10 , 10 )
118+ scales = (1.0 , 2.0 , 3.0 )
119+ coords = tuple (
120+ (d , sc * (np .arange (shp ) + tr ))
121+ for d , sc , shp , tr in zip (dims , scales , base_array .shape , translates )
122+ )
123+ dataarray = DataArray (base_array , coords = coords )
124+ downscaled = dataarray .coarsen ({"z" : 2 , "y" : 2 , "x" : 2 }).mean ()
125+
126+ multi = multiscale (dataarray , np .mean , (2 , 2 , 2 ), preserve_dtype = False )
127+
128+ assert_equal (multi [0 ], dataarray )
129+ assert_equal (multi [1 ], downscaled )
0 commit comments