Skip to content

Commit cfa49ff

Browse files
Convenience API for modifying Single Run Data (#43)
* Convenience API for modifying Single Run Data * Fixing discrepancies
1 parent 7ed9cc0 commit cfa49ff

2 files changed

Lines changed: 162 additions & 80 deletions

File tree

perda/core_data_structures/single_run_data.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,96 @@ def __getitem__(
7575
else:
7676
raise ValueError("Input must be a string, int, or DataInstance.")
7777

78+
def __setitem__(self, cpp_name: str, di: DataInstance) -> None:
    """
    Store a DataInstance under ``cpp_name`` via ``data[cpp_name] = di``.

    A name that is already present is routed to ``replace``; any other
    name is routed to ``add``.

    Parameters
    ----------
    cpp_name : str
        C++ variable name to insert or overwrite.
    di : DataInstance
        DataInstance to store. Must have non-None ``label`` and ``cpp_name``.

    Examples
    --------
    >>> data["my.new.var"] = DataInstance(timestamp_np=ts, value_np=vals, label="My var", cpp_name="my.new.var")
    >>> data["my.existing.var"] = updated_di
    """
    # Pick the handler first, then make one call with the shared arguments.
    handler = self.replace if cpp_name in self else self.add
    handler(cpp_name, di)
100+
101+
def add(self, cpp_name: str, di: DataInstance) -> None:
    """
    Insert a new derived DataInstance using a synthetic negative ID.

    A fresh DataInstance is stored, built from ``di``'s timestamps, values
    and label. Its ``cpp_name`` and ``var_id`` are always the ones chosen
    here, regardless of what ``di`` carries. All four lookup tables
    (``id_to_instance``, ``cpp_name_to_id``, ``id_to_cpp_name``,
    ``id_to_descript``) are updated.

    Parameters
    ----------
    cpp_name : str
        C++ variable name key for the new variable.
    di : DataInstance
        DataInstance to insert.

    Raises
    ------
    KeyError
        If ``cpp_name`` already exists; use ``replace`` to overwrite.
    """
    if cpp_name in self:
        raise KeyError(f"'{cpp_name}' already exists; use replace() to overwrite.")

    if di.cpp_name != cpp_name:
        print(f"Warning: replacing DataInstance.cpp_name with {cpp_name}")

    # Start from the size-derived candidate, then probe downward so an
    # existing entry is never overwritten if that ID is already taken.
    synthetic_id = -(len(self.id_to_instance) + 1)
    while synthetic_id in self.id_to_instance:
        synthetic_id -= 1

    # Only warn when the caller explicitly supplied a different ID.
    # NOTE(review): presumably var_id is None when left unset - confirm
    # against the DataInstance definition.
    if di.var_id is not None and di.var_id != synthetic_id:
        print(f"Warning: replacing DataInstance.var_id with {synthetic_id}")

    stored = DataInstance(
        timestamp_np=di.timestamp_np,
        value_np=di.value_np,
        label=di.label,
        var_id=synthetic_id,
        cpp_name=cpp_name,
    )
    self.id_to_instance[synthetic_id] = stored
    self.cpp_name_to_id[cpp_name] = synthetic_id
    self.id_to_cpp_name[synthetic_id] = cpp_name
    # The description table stores strings, so a missing label becomes "".
    self.id_to_descript[synthetic_id] = di.label or ""
133+
134+
def replace(self, cpp_name: str, di: DataInstance) -> None:
    """
    Overwrite the data of an existing variable in-place.

    The stored entry is rebuilt with ``di``'s ``timestamp_np`` and
    ``value_np``. The existing ``var_id``, ``cpp_name`` and ``label`` are
    retained, so the name/ID/description lookup tables stay consistent.
    A warning is printed for each identity field of ``di`` that differs
    from the stored one and is therefore ignored.

    Parameters
    ----------
    cpp_name : str
        C++ variable name of the variable to replace.
    di : DataInstance
        DataInstance whose ``timestamp_np`` and ``value_np`` replace the
        stored ones.

    Raises
    ------
    KeyError
        If ``cpp_name`` does not exist; use ``add`` to insert it.
    """
    if cpp_name not in self:
        raise KeyError(
            f"'{cpp_name}' not found; use add() to insert a new variable."
        )

    var_id = self.cpp_name_to_id[cpp_name]
    old = self.id_to_instance[var_id]

    if di.cpp_name != cpp_name:
        print(f"Warning: retaining old DataInstance.cpp_name {cpp_name}")
    # Only warn when the caller explicitly supplied a different ID.
    # NOTE(review): presumably var_id is None when left unset - confirm
    # against the DataInstance definition.
    if di.var_id is not None and di.var_id != var_id:
        print(f"Warning: retaining old DataInstance.var_id {var_id}")
    if di.label != old.label:
        print(f"Warning: retaining old DataInstance.label {old.label}")

    self.id_to_instance[var_id] = DataInstance(
        timestamp_np=di.timestamp_np,
        value_np=di.value_np,
        # Keep the stored label, as the warning above promises; this also
        # keeps it in sync with id_to_descript, which is not touched here.
        label=old.label,
        var_id=old.var_id,
        cpp_name=old.cpp_name,
    )
167+
78168
def __contains__(self, input_var_id_name: Union[str, int]) -> bool:
79169
"""
80170
Check if variable ID or variable name exists in the data.

perda/utils/preprocessing.py

Lines changed: 72 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -39,41 +39,6 @@
3939
PreprocessingStep = Callable[[SingleRunData], SingleRunData]
4040

4141

42-
def _replace(data: SingleRunData, cpp_name: str, new_values: NDArray[float64]) -> None:
    """Swap new values into an existing variable, keeping its other fields.

    The entry in ``data.id_to_instance`` is rebound to a new DataInstance
    that reuses the previous timestamps, label, ID and name.
    """
    target_id = data.cpp_name_to_id[cpp_name]
    previous = data.id_to_instance[target_id]
    replacement = DataInstance(
        timestamp_np=previous.timestamp_np,
        value_np=new_values,
        label=previous.label,
        var_id=previous.var_id,
        cpp_name=previous.cpp_name,
    )
    data.id_to_instance[target_id] = replacement
53-
54-
55-
def _add(
    data: SingleRunData,
    cpp_name: str,
    label: str,
    timestamp_np: NDArray[float64],
    value_np: NDArray[float64],
) -> None:
    """Register a new derived variable under a synthetic negative ID.

    The ID is derived from the current size of ``data.id_to_instance``.
    The new DataInstance is recorded in all four lookup tables of ``data``.
    """
    new_id = -1 - len(data.id_to_instance)
    data.id_to_instance[new_id] = DataInstance(
        timestamp_np=timestamp_np,
        value_np=value_np,
        label=label,
        var_id=new_id,
        cpp_name=cpp_name,
    )
    data.cpp_name_to_id[cpp_name] = new_id
    data.id_to_cpp_name[new_id] = cpp_name
    data.id_to_descript[new_id] = label
75-
76-
7742
def patch_ned_velocity(data: SingleRunData) -> SingleRunData:
7843
"""Correct a VectorNav bug where velocityBody.x/y/z contains NED instead of body-frame velocities.
7944
@@ -103,23 +68,45 @@ def patch_ned_velocity(data: SingleRunData) -> SingleRunData:
10368
[data[VECTORNAV_BODY_VEL_Y], data[VECTORNAV_BODY_VEL_Z], data[VECTORNAV_YAW]],
10469
)
10570

106-
vel_n = vel_n1.value_np
107-
vel_e = vel_e1.value_np
108-
vel_d = vel_d1.value_np
10971
yaw_rad = np.radians(yaw_deg.value_np)
11072

111-
_add(data, "velN", "NED North velocity (raw)", vel_n1.timestamp_np, vel_n.copy())
112-
_add(data, "velE", "NED East velocity (raw)", vel_e1.timestamp_np, vel_e.copy())
113-
_add(data, "velD", "NED Down velocity (raw)", vel_d1.timestamp_np, vel_d.copy())
73+
data["velN"] = DataInstance(
74+
timestamp_np=vel_n1.timestamp_np,
75+
value_np=vel_n1.value_np.copy(),
76+
label="NED North velocity (raw)",
77+
cpp_name="velN",
78+
)
79+
data["velE"] = DataInstance(
80+
timestamp_np=vel_e1.timestamp_np,
81+
value_np=vel_e1.value_np.copy(),
82+
label="NED East velocity (raw)",
83+
cpp_name="velE",
84+
)
85+
data["velD"] = DataInstance(
86+
timestamp_np=vel_d1.timestamp_np,
87+
value_np=vel_d1.value_np.copy(),
88+
label="NED Down velocity (raw)",
89+
cpp_name="velD",
90+
)
11491

11592
cos_y = np.cos(yaw_rad)
11693
sin_y = np.sin(yaw_rad)
117-
_replace(data, VECTORNAV_BODY_VEL_X, vel_n * cos_y + vel_e * sin_y) # forward
118-
_replace(data, VECTORNAV_BODY_VEL_Y, -vel_n * sin_y + vel_e * cos_y) # right
94+
data[VECTORNAV_BODY_VEL_X] = DataInstance(
95+
timestamp_np=vel_n1.timestamp_np,
96+
value_np=vel_n1.value_np * cos_y + vel_e1.value_np * sin_y,
97+
label=data[VECTORNAV_BODY_VEL_X].label,
98+
cpp_name=VECTORNAV_BODY_VEL_X,
99+
) # forward
100+
data[VECTORNAV_BODY_VEL_Y] = DataInstance(
101+
timestamp_np=vel_e1.timestamp_np,
102+
value_np=-vel_n1.value_np * sin_y + vel_e1.value_np * cos_y,
103+
label=data[VECTORNAV_BODY_VEL_Y].label,
104+
cpp_name=VECTORNAV_BODY_VEL_Y,
105+
) # right
119106
# vel_z (down) is identical in NED and FRD — no change needed
120107

121108
print(
122-
f"patch_ned_velocity: preserved raw NED in velN/velE/velD, rotated {len(vel_n)} points to body frame"
109+
f"patch_ned_velocity: preserved raw NED in velN/velE/velD, rotated {len(vel_n1)} points to body frame"
123110
)
124111
return data
125112

@@ -147,14 +134,18 @@ def convert_wheelspeeds_to_m_per_s(data: SingleRunData) -> SingleRunData:
147134
di = data[col]
148135
backup_name = col + "_mph"
149136
if backup_name not in data:
150-
_add(
151-
data,
152-
backup_name,
153-
(di.label or col) + " (mph backup)",
154-
di.timestamp_np,
155-
di.value_np.copy(),
137+
data[backup_name] = DataInstance(
138+
timestamp_np=di.timestamp_np,
139+
value_np=di.value_np.copy(),
140+
label=(di.label or col) + " (mph backup)",
141+
cpp_name=backup_name,
156142
)
157-
_replace(data, col, mph_to_m_per_s(di.value_np))
143+
data[col] = DataInstance(
144+
timestamp_np=di.timestamp_np,
145+
value_np=mph_to_m_per_s(di.value_np),
146+
label=di.label,
147+
cpp_name=col,
148+
)
158149

159150
print(
160151
f"convert_wheelspeeds_to_m_per_s: converted {len(cols)} channels mph → m/s, backups in *_mph"
@@ -182,27 +173,30 @@ def correct_motor_data(data: SingleRunData) -> SingleRunData:
182173

183174
backup_name = MOTOR_RPM + "_raw"
184175
if backup_name not in data:
185-
_add(
186-
data,
187-
backup_name,
188-
"Motor RPM raw (pre-flip)",
189-
di.timestamp_np,
190-
raw_rpm.copy(),
176+
data[backup_name] = DataInstance(
177+
timestamp_np=di.timestamp_np,
178+
value_np=raw_rpm.copy(),
179+
label="Motor RPM raw (pre-flip)",
180+
cpp_name=backup_name,
191181
)
192182

193183
flipped = -raw_rpm
194-
_replace(data, MOTOR_RPM, flipped)
184+
data[MOTOR_RPM] = DataInstance(
185+
timestamp_np=di.timestamp_np,
186+
value_np=flipped,
187+
label=di.label,
188+
cpp_name=MOTOR_RPM,
189+
)
195190

196191
tire_radius_m = in_to_m(TIRE_RADIUS_IN)
197192
wheel_speed: NDArray[float64] = (
198193
flipped * 2.0 * np.pi * tire_radius_m / (60.0 * GEAR_RATIO)
199194
)
200-
_add(
201-
data,
202-
MOTOR_WHEELSPEED,
203-
"Driven wheel speed from motor RPM (m/s)",
204-
di.timestamp_np,
205-
wheel_speed,
195+
data[MOTOR_WHEELSPEED] = DataInstance(
196+
timestamp_np=di.timestamp_np,
197+
value_np=wheel_speed,
198+
label="Driven wheel speed from motor RPM (m/s)",
199+
cpp_name=MOTOR_WHEELSPEED,
206200
)
207201

208202
print(
@@ -275,24 +269,19 @@ def __call__(
275269
backup_name = STEERING_ANGLE + "_original"
276270
if STEERING_ANGLE in data and backup_name not in data:
277271
orig = data[STEERING_ANGLE]
278-
_add(
279-
data,
280-
backup_name,
281-
(orig.label or STEERING_ANGLE) + " (original)",
282-
orig.timestamp_np,
283-
orig.value_np.copy(),
272+
data[backup_name] = DataInstance(
273+
timestamp_np=orig.timestamp_np,
274+
value_np=orig.value_np.copy(),
275+
label=(orig.label or STEERING_ANGLE) + " (original)",
276+
cpp_name=backup_name,
284277
)
285278

286-
if STEERING_ANGLE in data:
287-
_replace(data, STEERING_ANGLE, recomputed)
288-
else:
289-
_add(
290-
data,
291-
STEERING_ANGLE,
292-
"Steering angle recomputed from raw voltage (deg)",
293-
raw_di.timestamp_np,
294-
recomputed,
295-
)
279+
data[STEERING_ANGLE] = DataInstance(
280+
timestamp_np=raw_di.timestamp_np,
281+
value_np=recomputed,
282+
label="Steering angle recomputed from raw voltage (deg)",
283+
cpp_name=STEERING_ANGLE,
284+
)
296285

297286
cal_str = ", ".join(f"({v:.2f}V, {a:+.1f}°)" for v, a in self.pts)
298287
print(
@@ -302,6 +291,9 @@ def __call__(
302291
return data
303292

304293

294+
# This is the global callable that should actually be used in the preprocessing pipeline.
295+
# Either pass it directly, or call it with a custom calibration: correct_steering_angle(calibration=...).
296+
# This should be treated like a function that supports partial application.
305297
correct_steering_angle = CorrectSteeringAngleLambda()
306298

307299

0 commit comments

Comments
 (0)