Skip to content

Commit 9297e50

Browse files
jonasvddjvdd
andauthored
fix: pickle compatible implementation of downsamplers (#44)
* 🔥 first pickle compatible implementation of downsamplers * 🖊️ review code * 🙈 fix linting issue --------- Co-authored-by: jvdd <[email protected]>
1 parent ae60a28 commit 9297e50

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

tests/test_tsdownsample.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def generate_all_downsamplers() -> Iterable[AbstractDownsampler]:
3838

3939

4040
@pytest.mark.parametrize("downsampler", generate_all_downsamplers())
41-
def test_serialization(downsampler: AbstractDownsampler):
41+
def test_serialization_copy(downsampler: AbstractDownsampler):
4242
"""Test serialization."""
4343
from copy import copy, deepcopy
4444

@@ -53,6 +53,19 @@ def test_serialization(downsampler: AbstractDownsampler):
5353
assert np.all(orig_downsampled == ddc_downsampled)
5454

5555

56+
@pytest.mark.parametrize("downsampler", generate_all_downsamplers())
57+
def test_serialization_pickle(downsampler: AbstractDownsampler):
58+
"""Test serialization."""
59+
import pickle
60+
61+
dc = pickle.loads(pickle.dumps(downsampler))
62+
63+
arr = np.arange(10_000)
64+
orig_downsampled = downsampler.downsample(arr, n_out=100)
65+
dc_downsampled = dc.downsample(arr, n_out=100)
66+
assert np.all(orig_downsampled == dc_downsampled)
67+
68+
5669
@pytest.mark.parametrize("downsampler", generate_rust_downsamplers())
5770
def test_rust_downsampler(downsampler: AbstractDownsampler):
5871
"""Test the Rust downsamplers."""

tsdownsample/downsamplers.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111

1212
class MinMaxDownsampler(AbstractRustDownsampler):
13-
def __init__(self) -> None:
14-
super().__init__(_tsdownsample_rs.minmax)
13+
@property
14+
def rust_mod(self):
15+
return _tsdownsample_rs.minmax
1516

1617
@staticmethod
1718
def _check_valid_n_out(n_out: int):
@@ -21,8 +22,9 @@ def _check_valid_n_out(n_out: int):
2122

2223

2324
class M4Downsampler(AbstractRustDownsampler):
24-
def __init__(self):
25-
super().__init__(_tsdownsample_rs.m4)
25+
@property
26+
def rust_mod(self):
27+
return _tsdownsample_rs.m4
2628

2729
@staticmethod
2830
def _check_valid_n_out(n_out: int):
@@ -32,13 +34,15 @@ def _check_valid_n_out(n_out: int):
3234

3335

3436
class LTTBDownsampler(AbstractRustDownsampler):
35-
def __init__(self):
36-
super().__init__(_tsdownsample_rs.lttb)
37+
@property
38+
def rust_mod(self):
39+
return _tsdownsample_rs.lttb
3740

3841

3942
class MinMaxLTTBDownsampler(AbstractRustDownsampler):
40-
def __init__(self):
41-
super().__init__(_tsdownsample_rs.minmaxlttb)
43+
@property
44+
def rust_mod(self):
45+
return _tsdownsample_rs.minmaxlttb
4246

4347
def downsample(
4448
self, *args, n_out: int, minmax_ratio: int = 30, parallel: bool = False, **_

tsdownsample/downsampling_interface.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,47 @@ def downsample(self, *args, n_out: int, **kwargs): # x and y are optional
143143
class AbstractRustDownsampler(AbstractDownsampler, ABC):
144144
"""RustDownsampler interface-class, subclassed by concrete downsamplers."""
145145

146-
def __init__(self, resampling_mod: ModuleType):
146+
def __init__(self):
147147
super().__init__(_rust_dtypes, _y_rust_dtypes) # same for x and y
148-
self.rust_mod = resampling_mod
149148

150-
# Store the single core sub module
151-
self.mod_single_core = self.rust_mod.scalar
149+
@property
150+
def rust_mod(self) -> ModuleType:
151+
"""The compiled Rust module for the current downsampler."""
152+
raise NotImplementedError
153+
154+
@property
155+
def mod_single_core(self) -> ModuleType:
156+
"""Get the single-core Rust module.
157+
158+
Returns
159+
-------
160+
ModuleType
161+
If SIMD compiled module is available, that one is returned. Otherwise, the
162+
scalar compiled module is returned.
163+
"""
152164
if hasattr(self.rust_mod, "simd"):
153165
# use SIMD implementation if available
154-
self.mod_single_core = self.rust_mod.simd
166+
return self.rust_mod.simd
167+
return self.rust_mod.scalar
168+
169+
@property
170+
def mod_multi_core(self) -> Union[ModuleType, None]:
171+
"""Get the multi-core Rust module.
155172
156-
# Store the multi-core sub module (if present)
157-
self.mod_multi_core = None # no multi-core implementation (default)
173+
Returns
174+
-------
175+
ModuleType or None
176+
If SIMD parallel compiled module is available, that one is returned.
177+
Otherwise, the scalar parallel compiled module is returned.
178+
If no parallel compiled module is available, None is returned.
179+
"""
158180
if hasattr(self.rust_mod, "simd_parallel"):
159181
# use SIMD implementation if available
160-
self.mod_multi_core = self.rust_mod.simd_parallel
182+
return self.rust_mod.simd_parallel
161183
elif hasattr(self.rust_mod, "scalar_parallel"):
162184
# use scalar implementation if available (when no SIMD available)
163-
self.mod_multi_core = self.rust_mod.scalar_parallel
185+
return self.rust_mod.scalar_parallel
186+
return None # no parallel compiled module available
164187

165188
@staticmethod
166189
def _switch_mod_with_y(

0 commit comments

Comments
 (0)