Skip to content

Commit 14a6390

Browse files
committed
fix for negative offsets
1 parent 0421f8a commit 14a6390

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

src/xarray_multiscale/multiscale.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def downscale(
291291
new_coords = tuple(
292292
DataArray(
293293
(offset * abs(base_coords[bc][1] - base_coords[bc][0]))
294-
+ (base_coords[bc][:s] * sc),
294+
+ (base_coords[bc][:s] * sc)
295+
- base_coords[bc][0],
295296
name=base_coords[bc].name,
296297
attrs=base_coords[bc].attrs,
297298
)

tests/test_multiscale.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from xarray.core import dataarray
23
from xarray_multiscale.multiscale import (
34
downscale,
45
prepad,
@@ -9,6 +10,7 @@
910
import dask.array as da
1011
import numpy as np
1112
from xarray import DataArray
13+
from xarray.testing import assert_equal
1214

1315

1416
def test_downscale_depth():
@@ -23,11 +25,13 @@ def test_downscale_depth():
2325
assert get_downscale_depth((7, 3, 3), (2, 2, 2), pad=True) == 2
2426
assert get_downscale_depth((1500, 5495, 5200), (2, 2, 2)) == 10
2527

26-
@pytest.mark.parametrize(("size","scale"), ((10,2), (11,2), (12,2), (13,2)))
28+
29+
@pytest.mark.parametrize(("size", "scale"), ((10, 2), (11, 2), (12, 2), (13, 2)))
2730
def test_even_padding(size: int, scale: int) -> None:
2831
assert (size + even_padding(size, scale)) % scale == 0
2932

30-
@pytest.mark.parametrize('dim', (1,2,3,4))
33+
34+
@pytest.mark.parametrize("dim", (1, 2, 3, 4))
3135
def test_prepad(dim: int) -> None:
3236
size = (10,) * dim
3337
chunks = (9,) * dim
@@ -53,17 +57,14 @@ def test_downscale_2d():
5357
arr_dask = da.from_array(arr_numpy, chunks=chunks)
5458
arr_xarray = DataArray(arr_dask)
5559

56-
downscaled_numpy_float = downscale(
57-
arr_numpy, np.mean, scale).compute()
60+
downscaled_numpy_float = downscale(arr_numpy, np.mean, scale).compute()
5861

59-
downscaled_dask_float = downscale(
60-
arr_dask, np.mean, scale).compute()
62+
downscaled_dask_float = downscale(arr_dask, np.mean, scale).compute()
6163

62-
downscaled_xarray_float = downscale(
63-
arr_xarray, np.mean, scale).compute()
64+
downscaled_xarray_float = downscale(arr_xarray, np.mean, scale).compute()
6465

6566
answer_float = np.array([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]])
66-
67+
6768
assert np.array_equal(downscaled_numpy_float, answer_float)
6869
assert np.array_equal(downscaled_dask_float, answer_float)
6970
assert np.array_equal(downscaled_xarray_float, answer_float)
@@ -77,28 +78,52 @@ def test_multiscale():
7778
cell = np.zeros(np.prod(chunks)).astype("float")
7879
cell[0] = 1
7980
cell = cell.reshape(*chunks)
80-
array = np.tile(cell, np.ceil(np.divide(shape, chunks)).astype("int"))[cropslice]
81-
82-
pyr_trimmed = multiscale(array, np.mean, 2, pad_mode=None)
83-
pyr_padded = multiscale(array, np.mean, 2, pad_mode="reflect")
84-
pyr_trimmed_unchained = multiscale(array, np.mean, 2, pad_mode=None, chained=False)
81+
base_array = np.tile(cell, np.ceil(np.divide(shape, chunks)).astype("int"))[
82+
cropslice
83+
]
84+
pyr_trimmed = multiscale(base_array, np.mean, 2, pad_mode=None)
85+
pyr_padded = multiscale(base_array, np.mean, 2, pad_mode="reflect")
86+
pyr_trimmed_unchained = multiscale(
87+
base_array, np.mean, 2, pad_mode=None, chained=False
88+
)
8589
assert [p.shape for p in pyr_padded] == [
8690
shape,
8791
(5, 5, 5),
8892
(3, 3, 3),
8993
(2, 2, 2),
9094
(1, 1, 1),
9195
]
92-
assert [p.shape for p in pyr_trimmed] == [shape, (4, 4, 4), (2, 2, 2), (1,1,1)]
96+
assert [p.shape for p in pyr_trimmed] == [shape, (4, 4, 4), (2, 2, 2), (1, 1, 1)]
9397

9498
# check that the first multiscale array is identical to the input data
95-
assert np.array_equal(pyr_padded[0].data.compute(), array)
96-
assert np.array_equal(pyr_trimmed[0].data.compute(), array)
99+
assert np.array_equal(pyr_padded[0].data.compute(), base_array)
100+
assert np.array_equal(pyr_trimmed[0].data.compute(), base_array)
97101

98102
assert np.array_equal(
99103
pyr_trimmed[-2].data.mean().compute(), pyr_trimmed[-1].data.compute().mean()
100104
)
101105
assert np.array_equal(
102-
pyr_trimmed_unchained[-2].data.mean().compute(), pyr_trimmed_unchained[-1].data.compute().mean()
106+
pyr_trimmed_unchained[-2].data.mean().compute(),
107+
pyr_trimmed_unchained[-1].data.compute().mean(),
103108
)
104109
assert np.allclose(pyr_padded[0].data.mean().compute(), 0.17146776406035666)
110+
111+
112+
def test_coords():
113+
dims = ("z", "y", "x")
114+
shape = (16,) * len(dims)
115+
base_array = np.random.randint(0, 255, shape, dtype="uint8")
116+
117+
translates = (0.0, -10, 10)
118+
scales = (1.0, 2.0, 3.0)
119+
coords = tuple(
120+
(d, sc * (np.arange(shp) + tr))
121+
for d, sc, shp, tr in zip(dims, scales, base_array.shape, translates)
122+
)
123+
dataarray = DataArray(base_array, coords=coords)
124+
downscaled = dataarray.coarsen({"z": 2, "y": 2, "x": 2}).mean()
125+
126+
multi = multiscale(dataarray, np.mean, (2, 2, 2), preserve_dtype=False)
127+
128+
assert_equal(multi[0], dataarray)
129+
assert_equal(multi[1], downscaled)

0 commit comments

Comments
 (0)