Skip to content

Commit 26a4624

Browse files
committed
ruff formatting
1 parent f3c05c4 commit 26a4624

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

src/ezmsg/sigproc/downsample.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
@consumer
1616
def downsample(
17-
axis: str | None = None, target_rate: float | None = None
17+
axis: str | None = None, target_rate: float | None = None, factor: int | None = None
1818
) -> typing.Generator[AxisArray, AxisArray, None]:
1919
"""
2020
Construct a generator that yields a downsampled version of the data .send() to it.
@@ -28,6 +28,7 @@ def downsample(
2828
Note: The axis must exist in the message .axes and be of type AxisArray.LinearAxis.
2929
target_rate: Desired rate after downsampling. The actual rate will be the nearest integer factor of the
3030
input rate that is the same or higher than the target rate.
31+
factor: Explicitly specify downsample factor. If specified, target_rate is ignored.
3132
3233
Returns:
3334
A primed generator object ready to receive an :obj:`AxisArray` via `.send(axis_array)`
@@ -39,7 +40,7 @@ def downsample(
3940
msg_out = AxisArray(np.array([]), dims=[""])
4041

4142
# state variables
42-
factor: int = 0 # The integer downsampling factor. It will be determined based on the target rate.
43+
q: int = 0 # The integer downsampling factor. It will be determined based on the target rate.
4344
s_idx: int = 0 # Index of the next msg's first sample into the virtual rotating ds_factor counter.
4445

4546
check_input = {"gain": None, "key": None}
@@ -61,19 +62,21 @@ def downsample(
6162
check_input["key"] = msg_in.key
6263
# Reset state variables
6364
s_idx = 0
64-
if target_rate is None:
65-
factor = 1
65+
if factor is not None:
66+
q = factor
67+
elif target_rate is None:
68+
q = 1
6669
else:
67-
factor = int(1 / (axis_info.gain * target_rate))
68-
if factor < 1:
70+
q = int(1 / (axis_info.gain * target_rate))
71+
if q < 1:
6972
ez.logger.warning(
7073
f"Target rate {target_rate} cannot be achieved with input rate of {1/axis_info.gain}."
7174
"Setting factor to 1."
7275
)
73-
factor = 1
76+
q = 1
7477

7578
n_samples = msg_in.data.shape[axis_idx]
76-
samples = np.arange(s_idx, s_idx + n_samples) % factor
79+
samples = np.arange(s_idx, s_idx + n_samples) % q
7780
if n_samples > 0:
7881
# Update state for next iteration.
7982
s_idx = samples[-1] + 1
@@ -92,7 +95,7 @@ def downsample(
9295
**msg_in.axes,
9396
axis: replace(
9497
axis_info,
95-
gain=axis_info.gain * factor,
98+
gain=axis_info.gain * q,
9699
offset=axis_info.offset + axis_info.gain * n_step,
97100
),
98101
},
@@ -107,6 +110,7 @@ class DownsampleSettings(ez.Settings):
107110

108111
axis: str | None = None
109112
target_rate: float | None = None
113+
factor: int | None = None
110114

111115

112116
class Downsample(GenAxisArray):
@@ -116,5 +120,7 @@ class Downsample(GenAxisArray):
116120

117121
def construct_generator(self):
118122
self.STATE.gen = downsample(
119-
axis=self.SETTINGS.axis, target_rate=self.SETTINGS.target_rate
123+
axis=self.SETTINGS.axis,
124+
target_rate=self.SETTINGS.target_rate,
125+
factor=self.SETTINGS.factor,
120126
)

tests/test_downsample.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
@pytest.mark.parametrize("block_size", [1, 5, 10, 20])
2323
@pytest.mark.parametrize("target_rate", [19.0, 9.5, 6.3])
24-
def test_downsample_core(block_size: int, target_rate: float):
24+
@pytest.mark.parametrize("factor", [None, 1, 2])
25+
def test_downsample_core(block_size: int, target_rate: float, factor: int | None):
2526
in_fs = 19.0
2627
test_dur = 4.0
2728
n_channels = 2
@@ -60,7 +61,7 @@ def msg_generator():
6061
in_msgs = list(msg_generator())
6162
backup = [copy.deepcopy(msg) for msg in in_msgs]
6263

63-
proc = downsample(axis="time", target_rate=target_rate)
64+
proc = downsample(axis="time", target_rate=target_rate, factor=factor)
6465
out_msgs = []
6566
for msg in in_msgs:
6667
res = proc.send(msg)
@@ -70,7 +71,7 @@ def msg_generator():
7071
assert_messages_equal(in_msgs, backup)
7172

7273
# Assert correctness of gain
73-
expected_factor: int = int(in_fs // target_rate)
74+
expected_factor: int = int(in_fs // target_rate) if factor is None else factor
7475
assert all(msg.axes["time"].gain == expected_factor / in_fs for msg in out_msgs)
7576

7677
# Assert messages have the correct timestamps
@@ -132,7 +133,13 @@ def network(self) -> ez.NetworkDefinition:
132133

133134
@pytest.mark.parametrize("block_size", [10])
134135
@pytest.mark.parametrize("target_rate", [6.3])
135-
def test_downsample_system(block_size: int, target_rate: float, test_name: str | None = None):
136+
@pytest.mark.parametrize("factor", [None, 2])
137+
def test_downsample_system(
138+
block_size: int,
139+
target_rate: float,
140+
factor: int | None,
141+
test_name: str | None = None,
142+
):
136143
in_fs = 19.0
137144
num_msgs = int(4.0 / (block_size / in_fs)) # Ensure 4 seconds of data
138145

@@ -146,7 +153,7 @@ def test_downsample_system(block_size: int, target_rate: float, test_name: str |
146153
fs=in_fs,
147154
dispatch_rate=20.0,
148155
),
149-
down_settings=DownsampleSettings(target_rate=target_rate),
156+
down_settings=DownsampleSettings(target_rate=target_rate, factor=factor),
150157
log_settings=MessageLoggerSettings(output=test_filename),
151158
term_settings=TerminateTestSettings(time=1.0),
152159
)
@@ -160,7 +167,7 @@ def test_downsample_system(block_size: int, target_rate: float, test_name: str |
160167
ez.logger.info(f"Analyzing recording of { len( messages ) } messages...")
161168

162169
# Check fs
163-
expected_factor: int = int(in_fs // target_rate)
170+
expected_factor: int = int(in_fs // target_rate) if factor is None else factor
164171
out_fs = in_fs / expected_factor
165172
assert np.allclose(
166173
np.array([1 / msg.axes["time"].gain for msg in messages]),

0 commit comments

Comments
 (0)