@@ -16,12 +16,16 @@ def bool_to_int(expr: str) -> str:
1616
1717
1818class GotranJuliaCodePrinter (JuliaCodePrinter ):
19- def __init__ (self , * args , ** kwargs ):
19+ def __init__ (self , type_stable : bool = False , * args , ** kwargs ):
2020 super ().__init__ (* args , ** kwargs )
21+ self ._type_stable = type_stable
2122 self ._settings ["contract" ] = False
2223
2324 def _print_Float (self , flt ):
24- return self ._print (str (float (flt )))
25+ value = str (float (flt ))
26+ if self ._type_stable :
27+ return self ._print (f"T({ value } )" )
28+ return self ._print (value )
2529
2630 def _print_Piecewise (self , expr ):
2731 if isinstance (expr .args [0 ][0 ], Assignment ):
@@ -56,9 +60,9 @@ def _print_Indexed(self, expr):
5660
5761
5862class JuliaCodeGenerator (CodeGenerator ):
59- def __init__ (self , ode : ODE , remove_unused : bool = False ) -> None :
63+ def __init__ (self , ode : ODE , remove_unused : bool = False , type_stable : bool = False ) -> None :
6064 super ().__init__ (ode , remove_unused = remove_unused )
61- self ._printer = GotranJuliaCodePrinter ()
65+ self ._printer = GotranJuliaCodePrinter (type_stable = type_stable )
6266 # setattr(self, "_formatter", get_formatter(format=format))
6367
6468 @property
@@ -84,12 +88,23 @@ def _rhs_arguments(
8488 self , order : RHSArgument | str = RHSArgument .stp , const_states : bool = True
8589 ) -> Func :
8690 value = RHSArgument .get_value (order )
87- argument_dict = {
88- "s" : "states" ,
89- "t" : "t" ,
90- "p" : "parameters" ,
91- }
92- argument_list = [argument_dict [v ] for v in value ] + ["values" ]
91+ if self ._printer ._type_stable :
92+ argument_dict = {
93+ "s" : "states::AbstractVector{T}" ,
94+ "t" : "t::T" ,
95+ "p" : "parameters::AbstractVector{T}" ,
96+ }
97+ values = ["values::AbstractVector{T}" ]
98+ post_function_signature = " where T"
99+ else :
100+ argument_dict = {
101+ "s" : "states" ,
102+ "t" : "t" ,
103+ "p" : "parameters" ,
104+ }
105+ values = ["values" ]
106+ post_function_signature = ""
107+ argument_list = [argument_dict [v ] for v in value ] + values
93108 states = sympy .IndexedBase ("states" , shape = (self .ode .num_states ,), offset = 1 )
94109 parameters = sympy .IndexedBase ("parameters" , shape = (self .ode .num_parameters ,), offset = 1 )
95110 values = sympy .IndexedBase ("values" , shape = (self .ode .num_states ,), offset = 1 )
@@ -100,6 +115,7 @@ def _rhs_arguments(
100115 parameters = parameters ,
101116 values = values ,
102117 values_type = "" ,
118+ post_function_signature = post_function_signature ,
103119 )
104120
105121 def _scheme_arguments (
@@ -108,13 +124,26 @@ def _scheme_arguments(
108124 const_states : bool = True ,
109125 ) -> Func :
110126 value = SchemeArgument .get_value (order )
111- argument_dict = {
112- "s" : "states" ,
113- "t" : "t" ,
114- "d" : "dt" ,
115- "p" : "parameters" ,
116- }
117- argument_list = [argument_dict [v ] for v in value ] + ["values" ]
127+ if self ._printer ._type_stable :
128+ argument_dict = {
129+ "s" : "states::AbstractVector{T}" ,
130+ "t" : "t::T" ,
131+ "d" : "dt::T" ,
132+ "p" : "parameters::AbstractVector{T}" ,
133+ }
134+ values = ["values::AbstractVector{T}" ]
135+ post_function_signature = " where {T}"
136+ else :
137+ argument_dict = {
138+ "s" : "states" ,
139+ "t" : "t" ,
140+ "d" : "dt" ,
141+ "p" : "parameters" ,
142+ }
143+ values = ["values" ]
144+ post_function_signature = ""
145+
146+ argument_list = [argument_dict [v ] for v in value ] + values
118147 states = sympy .IndexedBase ("states" , shape = (self .ode .num_states ,), offset = 1 )
119148 parameters = sympy .IndexedBase ("parameters" , shape = (self .ode .num_parameters ,), offset = 1 )
120149 values = sympy .IndexedBase ("values" , shape = (self .ode .num_states ,), offset = 1 )
@@ -125,4 +154,5 @@ def _scheme_arguments(
125154 parameters = parameters ,
126155 values = values ,
127156 values_type = "" ,
157+ post_function_signature = post_function_signature ,
128158 )
0 commit comments