21
21
22
22
@pytest .mark .parametrize ("block_size" , [1 , 5 , 10 , 20 ])
23
23
@pytest .mark .parametrize ("target_rate" , [19.0 , 9.5 , 6.3 ])
24
- def test_downsample_core (block_size : int , target_rate : float ):
24
+ @pytest .mark .parametrize ("factor" , [None , 1 , 2 ])
25
+ def test_downsample_core (block_size : int , target_rate : float , factor : int | None ):
25
26
in_fs = 19.0
26
27
test_dur = 4.0
27
28
n_channels = 2
@@ -60,7 +61,7 @@ def msg_generator():
60
61
in_msgs = list (msg_generator ())
61
62
backup = [copy .deepcopy (msg ) for msg in in_msgs ]
62
63
63
- proc = downsample (axis = "time" , target_rate = target_rate )
64
+ proc = downsample (axis = "time" , target_rate = target_rate , factor = factor )
64
65
out_msgs = []
65
66
for msg in in_msgs :
66
67
res = proc .send (msg )
@@ -70,7 +71,7 @@ def msg_generator():
70
71
assert_messages_equal (in_msgs , backup )
71
72
72
73
# Assert correctness of gain
73
- expected_factor : int = int (in_fs // target_rate )
74
+ expected_factor : int = int (in_fs // target_rate ) if factor is None else factor
74
75
assert all (msg .axes ["time" ].gain == expected_factor / in_fs for msg in out_msgs )
75
76
76
77
# Assert messages have the correct timestamps
@@ -132,7 +133,13 @@ def network(self) -> ez.NetworkDefinition:
132
133
133
134
@pytest .mark .parametrize ("block_size" , [10 ])
134
135
@pytest .mark .parametrize ("target_rate" , [6.3 ])
135
- def test_downsample_system (block_size : int , target_rate : float , test_name : str | None = None ):
136
+ @pytest .mark .parametrize ("factor" , [None , 2 ])
137
+ def test_downsample_system (
138
+ block_size : int ,
139
+ target_rate : float ,
140
+ factor : int | None ,
141
+ test_name : str | None = None ,
142
+ ):
136
143
in_fs = 19.0
137
144
num_msgs = int (4.0 / (block_size / in_fs )) # Ensure 4 seconds of data
138
145
@@ -146,7 +153,7 @@ def test_downsample_system(block_size: int, target_rate: float, test_name: str |
146
153
fs = in_fs ,
147
154
dispatch_rate = 20.0 ,
148
155
),
149
- down_settings = DownsampleSettings (target_rate = target_rate ),
156
+ down_settings = DownsampleSettings (target_rate = target_rate , factor = factor ),
150
157
log_settings = MessageLoggerSettings (output = test_filename ),
151
158
term_settings = TerminateTestSettings (time = 1.0 ),
152
159
)
@@ -160,7 +167,7 @@ def test_downsample_system(block_size: int, target_rate: float, test_name: str |
160
167
ez .logger .info (f"Analyzing recording of { len ( messages ) } messages..." )
161
168
162
169
# Check fs
163
- expected_factor : int = int (in_fs // target_rate )
170
+ expected_factor : int = int (in_fs // target_rate ) if factor is None else factor
164
171
out_fs = in_fs / expected_factor
165
172
assert np .allclose (
166
173
np .array ([1 / msg .axes ["time" ].gain for msg in messages ]),
0 commit comments