33import casadi as ca
44import numpy as np
55import pytest
6- from cyecca .dynamics import ModelMX , ModelSX , input_var , output_var , param , state , symbolic
6+ from cyecca .dynamics import (
7+ ModelMX ,
8+ ModelSX ,
9+ input_var ,
10+ output_var ,
11+ param ,
12+ state ,
13+ symbolic ,
14+ )
715from cyecca .dynamics .composition import SubmodelProxy
816from cyecca .dynamics .integrators import rk4 , rk8 , build_rk_integrator , integrate_n_steps
917
@@ -14,7 +22,7 @@ class TestQuickStart:
1422 def test_readme_quickstart_example (self ):
1523 """Verify the mass-spring-damper example from README works correctly."""
1624 # This is the exact code from the README Quick Start section
17-
25+
1826 @symbolic
1927 class States :
2028 x : ca .SX = state (1 , 1.0 , "position" ) # Start at x=1
@@ -44,35 +52,35 @@ class Outputs:
4452 # Output the full state
4553 f_y = ca .vertcat (x .x , x .v )
4654
47- model .build (f_x = f_x , f_y = f_y , integrator = ' rk4' )
55+ model .build (f_x = f_x , f_y = f_y , integrator = " rk4" )
4856
4957 # Simulate free oscillation from x0=1
5058 result = model .simulate (0.0 , 10.0 , 0.01 )
51-
59+
5260 # Verify results match expected output
53- final_position = result ['x' ][0 , - 1 ]
54- final_velocity = result ['x' ][1 , - 1 ]
55-
61+ final_position = result ["x" ][0 , - 1 ]
62+ final_velocity = result ["x" ][1 , - 1 ]
63+
5664 # Check values are close to documented output
5765 assert abs (final_position - (- 0.529209 )) < 0.001
5866 assert abs (final_velocity - 0.323980 ) < 0.001
59-
67+
6068 # Verify we have the right number of timesteps
61- assert len (result ['t' ]) == 1001 # 0 to 10 with dt=0.01
62-
69+ assert len (result ["t" ]) == 1001 # 0 to 10 with dt=0.01
70+
6371 # Verify initial conditions
64- assert result ['x' ][0 , 0 ] == pytest .approx (1.0 )
65- assert result ['x' ][1 , 0 ] == pytest .approx (0.0 )
66-
72+ assert result ["x" ][0 , 0 ] == pytest .approx (1.0 )
73+ assert result ["x" ][1 , 0 ] == pytest .approx (0.0 )
74+
6775 # Verify oscillatory behavior (should cross zero at least once)
68- x_pos = result ['x' ][0 , :]
76+ x_pos = result ["x" ][0 , :]
6977 sign_changes = np .sum (np .diff (np .sign (x_pos )) != 0 )
7078 assert sign_changes >= 2 # At least one complete oscillation
71-
79+
7280 # Verify outputs match states
73- assert ' out' in result
74- assert np .allclose (result [' out' ][0 , :], result ['x' ][0 , :]) # position output
75- assert np .allclose (result [' out' ][1 , :], result ['x' ][1 , :]) # velocity output
81+ assert " out" in result
82+ assert np .allclose (result [" out" ][0 , :], result ["x" ][0 , :]) # position output
83+ assert np .allclose (result [" out" ][1 , :], result ["x" ][1 , :]) # velocity output
7684
7785
7886class TestModelCreate :
@@ -343,7 +351,8 @@ class EventIndicators:
343351 # Continuous state reset at event
344352 # Position: clamp to ground, velocity: reverse with energy loss
345353 f_m = ca .vertcat (
346- 0.0 , - p .e * x .v # h+ = 0 (clamp to ground) # v+ = -e * v (reverse and reduce)
354+ 0.0 ,
355+ - p .e * x .v , # h+ = 0 (clamp to ground) # v+ = -e * v (reverse and reduce)
347356 )
348357
349358 model .build (f_x = f_x , f_c = f_c , f_z = f_z , f_m = f_m , integrator = "euler" )
@@ -762,28 +771,28 @@ def test_rk4_simple_exponential_decay(self):
762771 x_sym = ca .SX .sym ("x" , 1 )
763772 u_sym = ca .SX .sym ("u" , 0 ) # No inputs
764773 p_sym = ca .SX .sym ("p" , 1 ) # k parameter
765-
774+
766775 f_x = - p_sym * x_sym
767776 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
768-
777+
769778 # Create RK4 integrator with step size 0.1
770779 h = 0.1
771780 rk4_step = rk4 (f , h )
772-
781+
773782 # Initial conditions
774783 x0 = ca .DM ([1.0 ])
775784 u = ca .DM ([])
776785 k = ca .DM ([1.0 ])
777-
786+
778787 # Integrate for 10 steps (total time = 1.0)
779788 x = x0
780789 for _ in range (10 ):
781790 x = rk4_step (x , u , k )
782-
791+
783792 # Analytical solution: x(t) = x0 * exp(-k*t)
784793 t_final = 1.0
785794 x_analytical = float (x0 ) * np .exp (- float (k ) * t_final )
786-
795+
787796 # Check accuracy (RK4 should be quite accurate)
788797 assert abs (float (x ) - x_analytical ) < 1e-6
789798
@@ -793,77 +802,77 @@ def test_rk4_with_substeps(self):
793802 x_sym = ca .SX .sym ("x" , 1 )
794803 u_sym = ca .SX .sym ("u" , 0 )
795804 p_sym = ca .SX .sym ("p" , 1 )
796-
805+
797806 f_x = - p_sym * x_sym
798807 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
799-
808+
800809 # Create RK4 with 10 substeps
801810 h = 1.0
802811 rk4_step = rk4 (f , h , N = 10 )
803-
812+
804813 x0 = ca .DM ([1.0 ])
805814 u = ca .DM ([])
806815 k = ca .DM ([1.0 ])
807-
816+
808817 # Single step with substeps
809818 x_final = rk4_step (x0 , u , k )
810-
819+
811820 # Analytical solution at t=1.0
812821 x_analytical = float (x0 ) * np .exp (- float (k ) * 1.0 )
813-
822+
814823 assert abs (float (x_final ) - x_analytical ) < 1e-6
815824
816825 def test_rk4_with_inputs (self ):
817826 """Test RK4 with inputs: dx/dt = u - k*x."""
818827 x_sym = ca .SX .sym ("x" , 1 )
819828 u_sym = ca .SX .sym ("u" , 1 )
820829 p_sym = ca .SX .sym ("p" , 1 )
821-
830+
822831 f_x = u_sym - p_sym * x_sym
823832 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
824-
833+
825834 h = 0.1
826835 rk4_step = rk4 (f , h )
827-
836+
828837 x0 = ca .DM ([0.0 ])
829838 u = ca .DM ([1.0 ]) # Constant input
830839 k = ca .DM ([0.5 ])
831-
840+
832841 # Integrate for several steps
833842 x = x0
834843 for _ in range (20 ):
835844 x = rk4_step (x , u , k )
836-
845+
837846 # Analytical solution: x(t) = (u/k)*(1 - exp(-k*t)) for x0=0
838847 # With u=1, k=0.5, t=2.0: x = 2*(1 - exp(-1)) ≈ 1.264
839848 t_final = 2.0
840849 x_analytical = (float (u ) / float (k )) * (1 - np .exp (- float (k ) * t_final ))
841-
850+
842851 assert abs (float (x ) - x_analytical ) < 1e-6
843852
844853 def test_rk8_exponential_decay (self ):
845854 """Test RK8 integrator on exponential decay."""
846855 x_sym = ca .SX .sym ("x" , 1 )
847856 u_sym = ca .SX .sym ("u" , 0 )
848857 p_sym = ca .SX .sym ("p" , 1 )
849-
858+
850859 f_x = - p_sym * x_sym
851860 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
852-
861+
853862 # Use RK8 with default DOP853 tableau
854863 h = 0.5
855864 rk8_step = rk8 (f , h )
856-
865+
857866 x0 = ca .DM ([1.0 ])
858867 u = ca .DM ([])
859868 k = ca .DM ([1.0 ])
860-
869+
861870 # Single large step (RK8 should handle this well)
862871 x_final = rk8_step (x0 , u , k )
863-
872+
864873 # Analytical solution at t=0.5
865874 x_analytical = float (x0 ) * np .exp (- float (k ) * 0.5 )
866-
875+
867876 # RK8 should be very accurate even with large step
868877 assert abs (float (x_final ) - x_analytical ) < 1e-8
869878
@@ -873,56 +882,52 @@ def test_integrate_n_steps(self):
873882 x_sym = ca .SX .sym ("x" , 1 )
874883 u_sym = ca .SX .sym ("u" , 0 )
875884 p_sym = ca .SX .sym ("p" , 1 )
876-
885+
877886 f_x = - p_sym * x_sym
878887 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
879-
888+
880889 # Create one-step integrator
881890 h = 0.1
882891 rk4_step = rk4 (f , h )
883-
892+
884893 # Create N-step rollout
885894 N = 10
886895 rollout = integrate_n_steps (rk4_step , ca .DM ([1.0 ]), ca .DM ([]), ca .DM ([1.0 ]), N )
887-
896+
888897 # Execute rollout
889898 x0 = ca .DM ([1.0 ])
890899 u = ca .DM ([])
891900 k = ca .DM ([1.0 ])
892-
901+
893902 x_final = rollout (x0 , u , k )
894-
903+
895904 # Should match 10 steps of integration
896905 x_analytical = float (x0 ) * np .exp (- float (k ) * 1.0 )
897906 assert abs (float (x_final ) - x_analytical ) < 1e-6
898907
899908 def test_build_rk_integrator_custom_tableau (self ):
900909 """Test build_rk_integrator with a custom tableau."""
901910 # Define simple Euler method as a custom tableau
902- euler_tableau = {
903- "A" : [[0.0 ]],
904- "b" : [1.0 ],
905- "c" : [0.0 ]
906- }
907-
911+ euler_tableau = {"A" : [[0.0 ]], "b" : [1.0 ], "c" : [0.0 ]}
912+
908913 x_sym = ca .SX .sym ("x" , 1 )
909914 u_sym = ca .SX .sym ("u" , 0 )
910915 p_sym = ca .SX .sym ("p" , 1 )
911-
916+
912917 f_x = - p_sym * x_sym
913918 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
914-
919+
915920 h = 0.01
916921 euler_step = build_rk_integrator (f , h , euler_tableau , name = "euler" )
917-
922+
918923 # Take small steps with Euler method
919924 x = ca .DM ([1.0 ])
920925 u = ca .DM ([])
921926 k = ca .DM ([1.0 ])
922-
927+
923928 for _ in range (100 ): # 100 steps of 0.01 = t=1.0
924929 x = euler_step (x , u , k )
925-
930+
926931 # Euler is less accurate but should be reasonable with small steps
927932 x_analytical = np .exp (- 1.0 )
928933 assert abs (float (x ) - x_analytical ) < 0.01
@@ -933,30 +938,30 @@ def test_rk4_multidimensional(self):
933938 x_sym = ca .SX .sym ("x" , 2 ) # [position, velocity]
934939 u_sym = ca .SX .sym ("u" , 0 )
935940 p_sym = ca .SX .sym ("p" , 2 ) # [k, m]
936-
941+
937942 position = x_sym [0 ]
938943 velocity = x_sym [1 ]
939944 k = p_sym [0 ]
940945 m = p_sym [1 ]
941-
946+
942947 f_x = ca .vertcat (velocity , - k * position / m )
943948 f = ca .Function ("f" , [x_sym , u_sym , p_sym ], [f_x ])
944-
949+
945950 # Create integrator
946951 h = 0.01
947952 rk4_step = rk4 (f , h )
948-
953+
949954 # Initial conditions: x=1, v=0
950955 x0 = ca .DM ([1.0 , 0.0 ])
951956 u = ca .DM ([])
952957 params = ca .DM ([1.0 , 1.0 ]) # k=1, m=1 => omega=1
953-
958+
954959 # Integrate for one period (2*pi)
955960 n_steps = int (2 * np .pi / h )
956961 x = x0
957962 for _ in range (n_steps ):
958963 x = rk4_step (x , u , params )
959-
964+
960965 # After one period, should return to initial position
961966 assert abs (float (x [0 ]) - 1.0 ) < 0.01
962967 assert abs (float (x [1 ]) - 0.0 ) < 0.01
0 commit comments