77
88import numpy as np
99import simplekml
10- from scipy import integrate
10+ from scipy . integrate import BDF , DOP853 , LSODA , RK23 , RK45 , OdeSolver , Radau
1111
1212from ..mathutils .function import Function , funcify_method
1313from ..mathutils .vector_matrix import Matrix , Vector
2424 quaternions_to_spin ,
2525)
2626
27+ ODE_SOLVER_MAP = {
28+ 'RK23' : RK23 ,
29+ 'RK45' : RK45 ,
30+ 'DOP853' : DOP853 ,
31+ 'Radau' : Radau ,
32+ 'BDF' : BDF ,
33+ 'LSODA' : LSODA ,
34+ }
2735
28- class Flight : # pylint: disable=too-many-public-methods
36+
37+ # pylint: disable=too-many-public-methods
38+ # pylint: disable=too-many-instance-attributes
39+ class Flight :
2940 """Keeps all flight information and has a method to simulate flight.
3041
3142 Attributes
@@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
506517 verbose = False ,
507518 name = "Flight" ,
508519 equations_of_motion = "standard" ,
520+ ode_solver = "LSODA" ,
509521 ):
510522 """Run a trajectory simulation.
511523
@@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
581593 more restricted set of equations of motion that only works for
582594 solid propulsion rockets. Such equations were used in RocketPy v0
583595 and are kept here for backwards compatibility.
596+ ode_solver : str, ``scipy.integrate.OdeSolver``, optional
597+ Integration method to use to solve the equations of motion ODE.
598+ Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
599+ 'LSODA' from ``scipy.integrate.solve_ivp``.
600+ Default is 'LSODA', which is recommended for most flights.
601+ A custom ``scipy.integrate.OdeSolver`` can be passed as well.
602+ For more information on the integration methods, see the scipy
603+ documentation [1]_.
604+
584605
585606 Returns
586607 -------
587608 None
609+
610+ References
611+ ----------
612+ .. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
588613 """
589614 # Save arguments
590615 self .env = environment
@@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
605630 self .terminate_on_apogee = terminate_on_apogee
606631 self .name = name
607632 self .equations_of_motion = equations_of_motion
633+ self .ode_solver = ode_solver
608634
609635 # Controller initialization
610636 self .__init_controllers ()
@@ -651,15 +677,16 @@ def __simulate(self, verbose):
651677
652678 # Create solver for this flight phase # TODO: allow different integrators
653679 self .function_evaluations .append (0 )
654- phase .solver = integrate .LSODA (
680+
681+ phase .solver = self ._solver (
655682 phase .derivative ,
656683 t0 = phase .t ,
657684 y0 = self .y_sol ,
658685 t_bound = phase .time_bound ,
659- min_step = self .min_time_step ,
660- max_step = self .max_time_step ,
661686 rtol = self .rtol ,
662687 atol = self .atol ,
688+ max_step = self .max_time_step ,
689+ min_step = self .min_time_step ,
663690 )
664691
665692 # Initialize phase time nodes
@@ -691,13 +718,14 @@ def __simulate(self, verbose):
691718 for node_index , node in self .time_iterator (phase .time_nodes ):
692719 # Determine time bound for this time node
693720 node .time_bound = phase .time_nodes [node_index + 1 ].t
694- # NOTE: Setting the time bound and status for the phase solver,
695- # and updating its internal state for the next integration step.
696721 phase .solver .t_bound = node .time_bound
697- phase .solver ._lsoda_solver ._integrator .rwork [0 ] = phase .solver .t_bound
698- phase .solver ._lsoda_solver ._integrator .call_args [4 ] = (
699- phase .solver ._lsoda_solver ._integrator .rwork
700- )
722+ if self .__is_lsoda :
723+ phase .solver ._lsoda_solver ._integrator .rwork [0 ] = (
724+ phase .solver .t_bound
725+ )
726+ phase .solver ._lsoda_solver ._integrator .call_args [4 ] = (
727+ phase .solver ._lsoda_solver ._integrator .rwork
728+ )
701729 phase .solver .status = "running"
702730
703731 # Feed required parachute and discrete controller triggers
@@ -1185,6 +1213,8 @@ def __init_solver_monitors(self):
11851213 self .t = self .solution [- 1 ][0 ]
11861214 self .y_sol = self .solution [- 1 ][1 :]
11871215
1216+ self .__set_ode_solver (self .ode_solver )
1217+
11881218 def __init_equations_of_motion (self ):
11891219 """Initialize equations of motion."""
11901220 if self .equations_of_motion == "solid_propulsion" :
@@ -1222,6 +1252,28 @@ def __cache_sensor_data(self):
12221252 sensor_data [sensor ] = sensor .measured_data [:]
12231253 self .sensor_data = sensor_data
12241254
1255+ def __set_ode_solver (self , solver ):
1256+ """Sets the ODE solver to be used in the simulation.
1257+
1258+ Parameters
1259+ ----------
1260+ solver : str, ``scipy.integrate.OdeSolver``
1261+ Integration method to use to solve the equations of motion ODE,
1262+ or a custom ``scipy.integrate.OdeSolver``.
1263+ """
1264+ if isinstance (solver , OdeSolver ):
1265+ self ._solver = solver
1266+ else :
1267+ try :
1268+ self ._solver = ODE_SOLVER_MAP [solver ]
1269+ except KeyError as e :
1270+ raise ValueError (
1271+ f"Invalid ``ode_solver`` input: { solver } . "
1272+ f"Available options are: { ', ' .join (ODE_SOLVER_MAP .keys ())} "
1273+ ) from e
1274+
1275+ self .__is_lsoda = hasattr (self ._solver , "_lsoda_solver" )
1276+
12251277 @cached_property
12261278 def effective_1rl (self ):
12271279 """Original rail length minus the distance measured from nozzle exit
0 commit comments