Skip to content

Commit e9d87eb

Browse files
authored
JP-4323: Improve performance of tso median calculation in outlier detection (#10452)
1 parent 34543ea commit e9d87eb

3 files changed

Lines changed: 32 additions & 22 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Reduce memory usage and improve runtime of the rolling median used for TSO data. For a MIRI dataset of 2 files, each with 40 integrations, memory usage was reduced by a factor of 7 and runtime improved by a factor of 3.
Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
import numpy as np
2+
import pytest
23

34
from jwst.outlier_detection.tso import moving_median_over_zeroth_axis
45

56

6-
def test_rolling_median():
7+
@pytest.mark.parametrize(
8+
"w, expected_time_axis",
9+
[
10+
(3, [1, 1, 2, 3, 4, 5, 6, 7, 8, 8]),
11+
(4, [1.5, 1.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 7.5]),
12+
],
13+
)
14+
def test_rolling_median(w, expected_time_axis):
715
time_axis = np.arange(10)
8-
expected_time_axis = np.array([1, 1, 2, 3, 4, 5, 6, 7, 8, 8])
916
spatial_axis = np.ones((5, 5))
1017
arr = time_axis[:, np.newaxis, np.newaxis] * spatial_axis[np.newaxis, :, :]
1118

12-
w = 3
1319
result = moving_median_over_zeroth_axis(arr, w)
14-
expected = expected_time_axis[:, np.newaxis, np.newaxis] * spatial_axis[np.newaxis, :, :]
20+
expected = (
21+
np.array(expected_time_axis)[:, np.newaxis, np.newaxis] * spatial_axis[np.newaxis, :, :]
22+
)
1523
assert np.allclose(result, expected)

jwst/outlier_detection/tso.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,14 @@ def moving_median_over_zeroth_axis(x: np.ndarray, w: int) -> np.ndarray:
163163
"""
164164
Calculate the median of a moving window over the zeroth axis of an N-d array.
165165
166-
Algorithm works by expanding the array into an additional dimension
167-
where the new axis has the same length as the window size. Each entry in that
168-
axis is a copy of the original array shifted by 1 with respect to the previous
169-
entry, such that the rolling median is simply the median over the new axis.
170-
modified from https://stackoverflow.com/a/71154394, see link for more details.
166+
Slide a window of size w over the array along axis 0, and for each position,
167+
calculate the median of the values inside that window (across axis 0 only).
168+
The result at each step is stored at offset w // 2 within the window (its center when w is odd),
169+
producing an output array with the same shape as the input.
170+
171+
Because the window cannot fully overlap the data at the beginning and end,
172+
those edge positions are filled with the nearest computed median value to
173+
avoid missing data.
171174
172175
Parameters
173176
----------
@@ -183,17 +186,15 @@ def moving_median_over_zeroth_axis(x: np.ndarray, w: int) -> np.ndarray:
183186
"""
184187
if w <= 1:
185188
raise ValueError("Rolling median window size must be greater than 1.")
186-
shifted = np.zeros((x.shape[0] + w - 1, w, *x.shape[1:])) * np.nan
187-
for idx in range(w - 1):
188-
shifted[idx : -w + idx + 1, idx] = x
189-
shifted[idx + 1 :, idx + 1] = x
190-
medians: np.ndarray = np.median(shifted, axis=1)
191-
for idx in range(w - 1):
192-
medians[idx] = np.median(shifted[idx, : idx + 1])
193-
medians[-idx - 1] = np.median(shifted[-idx - 1, -idx - 1 :])
194-
medians = medians[(w - 1) // 2 : -(w - 1) // 2]
195-
189+
out = np.full(x.shape, np.nan)
190+
hw, odd_window = divmod(w, 2)
191+
for start_index in range(x.shape[0] - w + 1):
192+
end_index = start_index + w
193+
np.median(x[start_index:end_index], axis=0, out=out[start_index + hw])
196194
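# Fill in the edges with the nearest valid value. NOTE(review): when w == 2, hw == 1 and the even-window slice `out[-hw + 1 :]` below is `out[0:]`, so the whole output is overwritten with the last median; `out[x.shape[0] - hw + 1 :]` would leave it untouched.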
197-
medians[: w // 2] = medians[w // 2]
198-
medians[-w // 2 :] = medians[-w // 2]
199-
return medians
195+
out[:hw] = out[hw]
196+
if odd_window:
197+
out[-hw:] = out[-hw - 1]
198+
else:
199+
out[-hw + 1 :] = out[-hw]
200+
return out

0 commit comments

Comments
 (0)