Skip to content

Commit 6902342

Browse files
committed
Update testing.
Signed-off-by: James Goppert <james.goppert@gmail.com>
1 parent 2df0a52 commit 6902342

File tree

4 files changed

+136
-103
lines changed

4 files changed

+136
-103
lines changed

cyecca/planning/dubins.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,16 @@
66
77
Usage:
88
>>> from cyecca.planning import derive_dubins
9+
>>> import casadi as ca
910
>>> plan, eval_fn = derive_dubins()
11+
>>> p0, psi0 = ca.DM([0, 0]), 0.0 # Start position and heading
12+
>>> p1, psi1 = ca.DM([10, 10]), ca.pi/2 # End position and heading
13+
>>> R = 5.0 # Turn radius
1014
>>> cost, type, a1, d, a2, tp0, tp1, c0, c1 = plan(p0, psi0, p1, psi1, R)
15+
>>> s = 0.5 # Evaluation point along path (0 to 1)
1116
>>> x, y, psi = eval_fn(s, p0, psi0, a1, d, a2, tp0, tp1, c0, c1, R)
17+
>>> float(cost) > 0 # Path cost should be positive
18+
True
1219
1320
Functions:
1421
- derive_dubins() -> (dubins_fixedwing, dubins_eval)
@@ -395,10 +402,18 @@ def derive_dubins():
395402
)
396403

397404
# Pack cargo
398-
cargo_rsl = ca.vertcat(DubinsPathType.RSL, a1_rsl, d_rsl, a2_rsl, t0_rsl, t1_rsl, cr0, cl1)
399-
cargo_lsr = ca.vertcat(DubinsPathType.LSR, a1_lsr, d_lsr, a2_lsr, t0_lsr, t1_lsr, cl0, cr1)
400-
cargo_lsl = ca.vertcat(DubinsPathType.LSL, a1_lsl, d_lsl, a2_lsl, t0_lsl, t1_lsl, cl0, cl1)
401-
cargo_rsr = ca.vertcat(DubinsPathType.RSR, a1_rsr, d_rsr, a2_rsr, t0_rsr, t1_rsr, cr0, cr1)
405+
cargo_rsl = ca.vertcat(
406+
DubinsPathType.RSL, a1_rsl, d_rsl, a2_rsl, t0_rsl, t1_rsl, cr0, cl1
407+
)
408+
cargo_lsr = ca.vertcat(
409+
DubinsPathType.LSR, a1_lsr, d_lsr, a2_lsr, t0_lsr, t1_lsr, cl0, cr1
410+
)
411+
cargo_lsl = ca.vertcat(
412+
DubinsPathType.LSL, a1_lsl, d_lsl, a2_lsl, t0_lsl, t1_lsl, cl0, cl1
413+
)
414+
cargo_rsr = ca.vertcat(
415+
DubinsPathType.RSR, a1_rsr, d_rsr, a2_rsr, t0_rsr, t1_rsr, cr0, cr1
416+
)
402417

403418
min_cost, best_cargo = casadi_min_with_cargo(
404419
costs=[cost_rsl, cost_lsr, cost_lsl, cost_rsr],
@@ -481,7 +496,9 @@ def derive_dubins():
481496

482497
# Select segment
483498
in_arc1 = path_dist <= arc1_len
484-
in_straight = ca.logic_and(path_dist > arc1_len, path_dist <= arc1_len + straight_len)
499+
in_straight = ca.logic_and(
500+
path_dist > arc1_len, path_dist <= arc1_len + straight_len
501+
)
485502

486503
x_out = ca.if_else(in_arc1, x1, ca.if_else(in_straight, x2, x3))
487504
y_out = ca.if_else(in_arc1, y1, ca.if_else(in_straight, y2, y3))
@@ -595,7 +612,9 @@ def plot_dubins_path(p0, psi0, p1, psi1, R, plan, eval_fn, ax=None, n_points=200
595612

596613
# Draw circles
597614
for c in [c0, c1]:
598-
circ = plt.Circle((c[0], c[1]), R, fill=False, color="gray", alpha=0.6, linestyle="--")
615+
circ = plt.Circle(
616+
(c[0], c[1]), R, fill=False, color="gray", alpha=0.6, linestyle="--"
617+
)
599618
ax.add_patch(circ)
600619

601620
# Evaluate path

tests/test_dynamics.py

Lines changed: 71 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
import casadi as ca
44
import numpy as np
55
import pytest
6-
from cyecca.dynamics import ModelMX, ModelSX, input_var, output_var, param, state, symbolic
6+
from cyecca.dynamics import (
7+
ModelMX,
8+
ModelSX,
9+
input_var,
10+
output_var,
11+
param,
12+
state,
13+
symbolic,
14+
)
715
from cyecca.dynamics.composition import SubmodelProxy
816
from cyecca.dynamics.integrators import rk4, rk8, build_rk_integrator, integrate_n_steps
917

@@ -14,7 +22,7 @@ class TestQuickStart:
1422
def test_readme_quickstart_example(self):
1523
"""Verify the mass-spring-damper example from README works correctly."""
1624
# This is the exact code from the README Quick Start section
17-
25+
1826
@symbolic
1927
class States:
2028
x: ca.SX = state(1, 1.0, "position") # Start at x=1
@@ -44,35 +52,35 @@ class Outputs:
4452
# Output the full state
4553
f_y = ca.vertcat(x.x, x.v)
4654

47-
model.build(f_x=f_x, f_y=f_y, integrator='rk4')
55+
model.build(f_x=f_x, f_y=f_y, integrator="rk4")
4856

4957
# Simulate free oscillation from x0=1
5058
result = model.simulate(0.0, 10.0, 0.01)
51-
59+
5260
# Verify results match expected output
53-
final_position = result['x'][0, -1]
54-
final_velocity = result['x'][1, -1]
55-
61+
final_position = result["x"][0, -1]
62+
final_velocity = result["x"][1, -1]
63+
5664
# Check values are close to documented output
5765
assert abs(final_position - (-0.529209)) < 0.001
5866
assert abs(final_velocity - 0.323980) < 0.001
59-
67+
6068
# Verify we have the right number of timesteps
61-
assert len(result['t']) == 1001 # 0 to 10 with dt=0.01
62-
69+
assert len(result["t"]) == 1001 # 0 to 10 with dt=0.01
70+
6371
# Verify initial conditions
64-
assert result['x'][0, 0] == pytest.approx(1.0)
65-
assert result['x'][1, 0] == pytest.approx(0.0)
66-
72+
assert result["x"][0, 0] == pytest.approx(1.0)
73+
assert result["x"][1, 0] == pytest.approx(0.0)
74+
6775
# Verify oscillatory behavior (should cross zero at least once)
68-
x_pos = result['x'][0, :]
76+
x_pos = result["x"][0, :]
6977
sign_changes = np.sum(np.diff(np.sign(x_pos)) != 0)
7078
assert sign_changes >= 2 # At least one complete oscillation
71-
79+
7280
# Verify outputs match states
73-
assert 'out' in result
74-
assert np.allclose(result['out'][0, :], result['x'][0, :]) # position output
75-
assert np.allclose(result['out'][1, :], result['x'][1, :]) # velocity output
81+
assert "out" in result
82+
assert np.allclose(result["out"][0, :], result["x"][0, :]) # position output
83+
assert np.allclose(result["out"][1, :], result["x"][1, :]) # velocity output
7684

7785

7886
class TestModelCreate:
@@ -343,7 +351,8 @@ class EventIndicators:
343351
# Continuous state reset at event
344352
# Position: clamp to ground, velocity: reverse with energy loss
345353
f_m = ca.vertcat(
346-
0.0, -p.e * x.v # h+ = 0 (clamp to ground) # v+ = -e * v (reverse and reduce)
354+
0.0,
355+
-p.e * x.v, # h+ = 0 (clamp to ground) # v+ = -e * v (reverse and reduce)
347356
)
348357

349358
model.build(f_x=f_x, f_c=f_c, f_z=f_z, f_m=f_m, integrator="euler")
@@ -762,28 +771,28 @@ def test_rk4_simple_exponential_decay(self):
762771
x_sym = ca.SX.sym("x", 1)
763772
u_sym = ca.SX.sym("u", 0) # No inputs
764773
p_sym = ca.SX.sym("p", 1) # k parameter
765-
774+
766775
f_x = -p_sym * x_sym
767776
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
768-
777+
769778
# Create RK4 integrator with step size 0.1
770779
h = 0.1
771780
rk4_step = rk4(f, h)
772-
781+
773782
# Initial conditions
774783
x0 = ca.DM([1.0])
775784
u = ca.DM([])
776785
k = ca.DM([1.0])
777-
786+
778787
# Integrate for 10 steps (total time = 1.0)
779788
x = x0
780789
for _ in range(10):
781790
x = rk4_step(x, u, k)
782-
791+
783792
# Analytical solution: x(t) = x0 * exp(-k*t)
784793
t_final = 1.0
785794
x_analytical = float(x0) * np.exp(-float(k) * t_final)
786-
795+
787796
# Check accuracy (RK4 should be quite accurate)
788797
assert abs(float(x) - x_analytical) < 1e-6
789798

@@ -793,77 +802,77 @@ def test_rk4_with_substeps(self):
793802
x_sym = ca.SX.sym("x", 1)
794803
u_sym = ca.SX.sym("u", 0)
795804
p_sym = ca.SX.sym("p", 1)
796-
805+
797806
f_x = -p_sym * x_sym
798807
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
799-
808+
800809
# Create RK4 with 10 substeps
801810
h = 1.0
802811
rk4_step = rk4(f, h, N=10)
803-
812+
804813
x0 = ca.DM([1.0])
805814
u = ca.DM([])
806815
k = ca.DM([1.0])
807-
816+
808817
# Single step with substeps
809818
x_final = rk4_step(x0, u, k)
810-
819+
811820
# Analytical solution at t=1.0
812821
x_analytical = float(x0) * np.exp(-float(k) * 1.0)
813-
822+
814823
assert abs(float(x_final) - x_analytical) < 1e-6
815824

816825
def test_rk4_with_inputs(self):
817826
"""Test RK4 with inputs: dx/dt = u - k*x."""
818827
x_sym = ca.SX.sym("x", 1)
819828
u_sym = ca.SX.sym("u", 1)
820829
p_sym = ca.SX.sym("p", 1)
821-
830+
822831
f_x = u_sym - p_sym * x_sym
823832
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
824-
833+
825834
h = 0.1
826835
rk4_step = rk4(f, h)
827-
836+
828837
x0 = ca.DM([0.0])
829838
u = ca.DM([1.0]) # Constant input
830839
k = ca.DM([0.5])
831-
840+
832841
# Integrate for several steps
833842
x = x0
834843
for _ in range(20):
835844
x = rk4_step(x, u, k)
836-
845+
837846
# Analytical solution: x(t) = (u/k)*(1 - exp(-k*t)) for x0=0
838847
# With u=1, k=0.5, t=2.0: x = 2*(1 - exp(-1)) ≈ 1.264
839848
t_final = 2.0
840849
x_analytical = (float(u) / float(k)) * (1 - np.exp(-float(k) * t_final))
841-
850+
842851
assert abs(float(x) - x_analytical) < 1e-6
843852

844853
def test_rk8_exponential_decay(self):
845854
"""Test RK8 integrator on exponential decay."""
846855
x_sym = ca.SX.sym("x", 1)
847856
u_sym = ca.SX.sym("u", 0)
848857
p_sym = ca.SX.sym("p", 1)
849-
858+
850859
f_x = -p_sym * x_sym
851860
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
852-
861+
853862
# Use RK8 with default DOP853 tableau
854863
h = 0.5
855864
rk8_step = rk8(f, h)
856-
865+
857866
x0 = ca.DM([1.0])
858867
u = ca.DM([])
859868
k = ca.DM([1.0])
860-
869+
861870
# Single large step (RK8 should handle this well)
862871
x_final = rk8_step(x0, u, k)
863-
872+
864873
# Analytical solution at t=0.5
865874
x_analytical = float(x0) * np.exp(-float(k) * 0.5)
866-
875+
867876
# RK8 should be very accurate even with large step
868877
assert abs(float(x_final) - x_analytical) < 1e-8
869878

@@ -873,56 +882,52 @@ def test_integrate_n_steps(self):
873882
x_sym = ca.SX.sym("x", 1)
874883
u_sym = ca.SX.sym("u", 0)
875884
p_sym = ca.SX.sym("p", 1)
876-
885+
877886
f_x = -p_sym * x_sym
878887
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
879-
888+
880889
# Create one-step integrator
881890
h = 0.1
882891
rk4_step = rk4(f, h)
883-
892+
884893
# Create N-step rollout
885894
N = 10
886895
rollout = integrate_n_steps(rk4_step, ca.DM([1.0]), ca.DM([]), ca.DM([1.0]), N)
887-
896+
888897
# Execute rollout
889898
x0 = ca.DM([1.0])
890899
u = ca.DM([])
891900
k = ca.DM([1.0])
892-
901+
893902
x_final = rollout(x0, u, k)
894-
903+
895904
# Should match 10 steps of integration
896905
x_analytical = float(x0) * np.exp(-float(k) * 1.0)
897906
assert abs(float(x_final) - x_analytical) < 1e-6
898907

899908
def test_build_rk_integrator_custom_tableau(self):
900909
"""Test build_rk_integrator with a custom tableau."""
901910
# Define simple Euler method as a custom tableau
902-
euler_tableau = {
903-
"A": [[0.0]],
904-
"b": [1.0],
905-
"c": [0.0]
906-
}
907-
911+
euler_tableau = {"A": [[0.0]], "b": [1.0], "c": [0.0]}
912+
908913
x_sym = ca.SX.sym("x", 1)
909914
u_sym = ca.SX.sym("u", 0)
910915
p_sym = ca.SX.sym("p", 1)
911-
916+
912917
f_x = -p_sym * x_sym
913918
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
914-
919+
915920
h = 0.01
916921
euler_step = build_rk_integrator(f, h, euler_tableau, name="euler")
917-
922+
918923
# Take small steps with Euler method
919924
x = ca.DM([1.0])
920925
u = ca.DM([])
921926
k = ca.DM([1.0])
922-
927+
923928
for _ in range(100): # 100 steps of 0.01 = t=1.0
924929
x = euler_step(x, u, k)
925-
930+
926931
# Euler is less accurate but should be reasonable with small steps
927932
x_analytical = np.exp(-1.0)
928933
assert abs(float(x) - x_analytical) < 0.01
@@ -933,30 +938,30 @@ def test_rk4_multidimensional(self):
933938
x_sym = ca.SX.sym("x", 2) # [position, velocity]
934939
u_sym = ca.SX.sym("u", 0)
935940
p_sym = ca.SX.sym("p", 2) # [k, m]
936-
941+
937942
position = x_sym[0]
938943
velocity = x_sym[1]
939944
k = p_sym[0]
940945
m = p_sym[1]
941-
946+
942947
f_x = ca.vertcat(velocity, -k * position / m)
943948
f = ca.Function("f", [x_sym, u_sym, p_sym], [f_x])
944-
949+
945950
# Create integrator
946951
h = 0.01
947952
rk4_step = rk4(f, h)
948-
953+
949954
# Initial conditions: x=1, v=0
950955
x0 = ca.DM([1.0, 0.0])
951956
u = ca.DM([])
952957
params = ca.DM([1.0, 1.0]) # k=1, m=1 => omega=1
953-
958+
954959
# Integrate for one period (2*pi)
955960
n_steps = int(2 * np.pi / h)
956961
x = x0
957962
for _ in range(n_steps):
958963
x = rk4_step(x, u, params)
959-
964+
960965
# After one period, should return to initial position
961966
assert abs(float(x[0]) - 1.0) < 0.01
962967
assert abs(float(x[1]) - 0.0) < 0.01

0 commit comments

Comments
 (0)