Skip to content

Commit 33fb946

Browse files
authored
Merge pull request #8 from duckietown/feat/lockstep-timeout-rk4-DTSW-7714
feat: PX4Multirotor performance and lockstep improvements DTSW-7714
2 parents ad14dcc + 449ecc1 commit 33fb946

File tree

2 files changed

+100
-32
lines changed

2 files changed

+100
-32
lines changed

rotorpy/vehicles/multirotor.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,15 @@ def __init__(self, quad_params, initial_state = {'x': np.array([0,0,0]),
179179

180180
self.aero = aero
181181

182-
# Integrator settings.
182+
# Integrator settings.
183183
if integrator_kwargs is None:
184184
self.integrator_kwargs = {'method':'RK45'}
185185
else:
186186
self.integrator_kwargs = integrator_kwargs
187187

188+
# Fixed-step RK4 option (much faster than solve_ivp for small timesteps)
189+
self.use_fixed_step = False
190+
188191
def extract_geometry(self):
189192
"""
190193
Extracts the geometry in self.rotors for efficient use later on in the computation of
@@ -233,19 +236,22 @@ def step(self, state, control, t_step):
233236
# The true motor speeds can not fall below min and max speeds.
234237
cmd_rotor_speeds = np.clip(cmd_rotor_speeds, self.rotor_speed_min, self.rotor_speed_max)
235238

236-
# Form autonomous ODE for constant inputs and integrate one time step.
237-
def s_dot_fn(t, s):
238-
return self._s_dot_fn(t, s, cmd_rotor_speeds)
239239
s = Multirotor._pack_state(state)
240240

241-
# Integrate
242-
sol = scipy.integrate.solve_ivp(
243-
s_dot_fn,
244-
(0.0, t_step),
245-
s,
246-
**self.integrator_kwargs
247-
)
248-
s = sol['y'][:, -1]
241+
if self.use_fixed_step:
242+
# Fixed-step RK4: 4 function evaluations, no adaptive overhead
243+
s = self._rk4_step(s, cmd_rotor_speeds, t_step)
244+
else:
245+
# Adaptive RK45 via scipy
246+
def s_dot_fn(t, s):
247+
return self._s_dot_fn(t, s, cmd_rotor_speeds)
248+
sol = scipy.integrate.solve_ivp(
249+
s_dot_fn,
250+
(0.0, t_step),
251+
s,
252+
**self.integrator_kwargs
253+
)
254+
s = sol['y'][:, -1]
249255

250256
# Unpack the state vector.
251257
state = Multirotor._unpack_state(s)
@@ -263,6 +269,17 @@ def s_dot_fn(t, s):
263269

264270
return state
265271

272+
def _rk4_step(self, s, cmd_rotor_speeds, dt):
273+
"""
274+
Single fixed-step RK4 integration. 7x faster than solve_ivp for small
275+
timesteps with identical accuracy at dt <= 4ms.
276+
"""
277+
k1 = self._s_dot_fn(0, s, cmd_rotor_speeds)
278+
k2 = self._s_dot_fn(0, s + 0.5 * dt * k1, cmd_rotor_speeds)
279+
k3 = self._s_dot_fn(0, s + 0.5 * dt * k2, cmd_rotor_speeds)
280+
k4 = self._s_dot_fn(0, s + dt * k3, cmd_rotor_speeds)
281+
return s + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)
282+
266283
def _s_dot_fn(self, t, s, cmd_rotor_speeds):
267284
"""
268285
Compute derivative of state for quadrotor given fixed control inputs as

rotorpy/vehicles/px4_multirotor.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from typing import Tuple
77

88
import math
9+
import time
10+
import statistics
911

1012
# Constants
1113
R_EARTH = 6378137.0 # meters
@@ -55,6 +57,7 @@ def __init__(
5557
mavlink_url="tcpin:localhost:4560",
5658
autopilot_controller=True,
5759
lockstep=True,
60+
lockstep_timeout=0.002,
5861
integrator_kwargs=None
5962
):
6063
integrator_kwargs = integrator_kwargs if integrator_kwargs is not None else {'method':'RK45', 'rtol':1e-2, 'atol':1e-4, 'max_step':0.05}
@@ -76,18 +79,22 @@ def __init__(
7679
enable_ground=enable_ground,
7780
integrator_kwargs=integrator_kwargs
7881
)
82+
# Use fixed-step RK4 for faster physics (7x vs solve_ivp, identical accuracy at dt<=4ms)
83+
self.use_fixed_step = True
7984
# Simulated IMU (with noise)
8085
self.imu = Imu()
8186
self._enable_imu_noise = True # Always add a bit of noise to avoid stale detection
8287
self.t = 0.0
8388

89+
8490
print("PX4Multirotor: Initializing MAVLink connection... on {}".format(mavlink_url))
8591
self.conn = mavutil.mavlink_connection(mavlink_url)
8692
self.conn.wait_heartbeat()
8793
print("PX4Multirotor: MAVLink connection established.")
8894

8995
self._autopilot_controller = autopilot_controller
9096
self._lockstep_enabled = lockstep
97+
self._lockstep_timeout = lockstep_timeout
9198
self._last_control = {'cmd_motor_speeds': np.zeros(quad_params['num_rotors'])}
9299

93100
@staticmethod
@@ -166,10 +173,20 @@ def geodetic_to_mavlink(lat_deg: float, lon_deg: float, alt_msl_m: float) -> Tup
166173

167174
def _fetch_latest_px4_control(self, blocking : bool = True):
168175
"""Fetch the latest HIL_ACTUATOR_CONTROLS message from PX4 and update control inputs."""
169-
170-
msg = self.conn.recv_match(type='HIL_ACTUATOR_CONTROLS', blocking=blocking, timeout=0.01)
171-
if msg is not None:
172-
return {'cmd_motor_speeds': [c * self.rotor_speed_max for c in msg.controls[:self.num_rotors]]}
176+
# Drain all queued messages non-blocking first
177+
latest = None
178+
while True:
179+
msg = self.conn.recv_match(type='HIL_ACTUATOR_CONTROLS', blocking=False)
180+
if msg is None:
181+
break
182+
latest = msg
183+
# If no message found and blocking requested, poll with retry until timeout
184+
if latest is None and blocking:
185+
deadline = time.perf_counter() + self._lockstep_timeout
186+
while latest is None and time.perf_counter() < deadline:
187+
latest = self.conn.recv_match(type='HIL_ACTUATOR_CONTROLS', blocking=True, timeout=0.01)
188+
if latest is not None:
189+
return {'cmd_motor_speeds': [c * self.rotor_speed_max for c in latest.controls[:self.num_rotors]]}
173190

174191
def _enu_to_ned_cmps(self, v_enu):
175192
v_n = float(v_enu[1])
@@ -190,25 +207,22 @@ def _imu(self, state, statedot):
190207
omega_frd = np.array([omega_flu[0], -omega_flu[1], -omega_flu[2]], dtype=float)
191208
return a_frd, omega_frd
192209

193-
def _send_hil_state_quaternion(self, state, statedot):
210+
def _send_hil_state_quaternion(self, state, a_frd_gt):
194211
"""
195212
Send HIL_STATE_QUATERNION message to PX4.
196-
213+
197214
Args:
198215
state: Current vehicle state
199-
statedot: State derivative (from the `Multirotor.statedot` method)
216+
a_frd_gt: Ground-truth acceleration in FRD frame (precomputed)
200217
"""
201218
# Convert cartesian ENU position to geodetic coordinates (latitude, longitude and height)
202219
lat_deg, lon_deg, height_meters = self.enu_to_geodetic(*state['x'])
203220
lat_e7, lon_e7, alt_mm = self.geodetic_to_mavlink(lat_deg, lon_deg, height_meters)
204-
221+
205222
# Convert quaternion from rotorpy to aerospace convention using ArduPilot's method
206223
quaternion_flu2ned = Ardupilot._quaternion_rotorpy_to_aerospace(state['q'])
207224
vx_cms, vy_cms, vz_cms = self._enu_to_ned_cmps(state['v'])
208225

209-
# Send the ground truth acceleration in the state message (without imu noise)
210-
a_flu_gt = self.imu.measurement(state, statedot, with_noise=False)["accel"]
211-
a_frd_gt = np.array([a_flu_gt[0], -a_flu_gt[1], -a_flu_gt[2]], dtype=float)
212226
a_frd_mg = np.clip(np.round(a_frd_gt / 9.80665 * 1000.0), INT_MIN, INT_MAX).astype(np.int16)
213227

214228
self.conn.mav.hil_state_quaternion_send(
@@ -253,17 +267,15 @@ def _baro(self, state):
253267
temperature_c = T0 - 0.0065 * alt_m - 273.15
254268
return abs_pressure_hpa, pressure_alt_m, temperature_c
255269

256-
def _send_hil_sensor(self, state, statedot):
270+
def _send_hil_sensor(self, state, a_frd, omega_frd):
257271
"""
258272
Send HIL_SENSOR message to PX4.
259273
260274
Args:
261275
state: Current vehicle state
262-
statedot: State derivative (computed externally)
276+
a_frd: Accelerometer reading in FRD frame (precomputed)
277+
omega_frd: Gyroscope reading in FRD frame (precomputed)
263278
"""
264-
# Get IMU measurements
265-
a_frd, omega_frd = self._imu(state, statedot)
266-
267279
# Magnetometer: Earth field rotated into body FRD frame (gauss)
268280
mag_frd = self._mag_body_frd(state)
269281

@@ -286,12 +298,21 @@ def _send_hil_sensor(self, state, statedot):
286298
)
287299

288300
def step(self, state, control, t_step):
301+
_t0 = time.perf_counter()
289302

290-
# Compute state derivative once for state and messages
291-
# and send both HIL messages
303+
# Compute state derivative once for messages
292304
statedot = self.statedot(state, control, 0.0)
293-
self._send_hil_state_quaternion(state, statedot)
294-
self._send_hil_sensor(state, statedot)
305+
306+
# Compute IMU measurements once (noisy for HIL_SENSOR, ground-truth for HIL_STATE)
307+
a_frd_noisy, omega_frd_noisy = self._imu(state, statedot)
308+
a_flu_gt = self.imu.measurement(state, statedot, with_noise=False)["accel"]
309+
a_frd_gt = np.array([a_flu_gt[0], -a_flu_gt[1], -a_flu_gt[2]], dtype=float)
310+
_t_statedot = time.perf_counter()
311+
312+
# Send both HIL messages with precomputed data
313+
self._send_hil_state_quaternion(state, a_frd_gt)
314+
self._send_hil_sensor(state, a_frd_noisy, omega_frd_noisy)
315+
_t_hil_send = time.perf_counter()
295316

296317
# Use PX4 commands only if autopilot_controller is True
297318
if self._autopilot_controller:
@@ -303,9 +324,39 @@ def step(self, state, control, t_step):
303324

304325
else: # In this case we use the control provided by the external controller
305326
self._last_control = control
327+
_t_px4_fetch = time.perf_counter()
306328

307329
state = super().step(state, self._last_control, t_step)
308330
self.state = state
309331
self.t += t_step
332+
_t_rk4 = time.perf_counter()
333+
334+
# Accumulate timing samples; print summary every 500 steps (~2s at 250Hz)
335+
if not hasattr(self, '_step_timing'):
336+
self._step_timing = {'statedot': [], 'hil_send': [], 'px4_fetch': [], 'rk4': [], 'total': []}
337+
self._step_count = 0
338+
self._step_count += 1
339+
self._step_timing['statedot'].append((_t_statedot - _t0) * 1e3)
340+
self._step_timing['hil_send'].append((_t_hil_send - _t_statedot) * 1e3)
341+
self._step_timing['px4_fetch'].append((_t_px4_fetch - _t_hil_send) * 1e3)
342+
self._step_timing['rk4'].append((_t_rk4 - _t_px4_fetch) * 1e3)
343+
self._step_timing['total'].append((_t_rk4 - _t0) * 1e3)
344+
345+
if self._step_count % 500 == 0:
346+
def _fmt(vals):
347+
avg = statistics.mean(vals)
348+
p99 = sorted(vals)[int(len(vals) * 0.99)]
349+
return f"avg={avg:.2f}ms p99={p99:.2f}ms"
350+
t = self._step_timing
351+
print(
352+
f"[PX4Multirotor.step timing | n={self._step_count}]\n"
353+
f" statedot : {_fmt(t['statedot'])}\n"
354+
f" hil_send : {_fmt(t['hil_send'])}\n"
355+
f" px4_fetch: {_fmt(t['px4_fetch'])} ← lockstep wait\n"
356+
f" rk4 : {_fmt(t['rk4'])}\n"
357+
f" total : {_fmt(t['total'])}"
358+
)
359+
for key in self._step_timing:
360+
self._step_timing[key].clear()
310361

311362
return state

0 commit comments

Comments
 (0)