Skip to content

Commit ab60c67

Browse files
authored
Merge pull request #172 from finsberg/generate-consistent-floats
Generate consistent floats
2 parents 091c2a9 + 6534099 commit ab60c67

File tree

7 files changed

+129
-19
lines changed

7 files changed

+129
-19
lines changed

src/gotranx/cli/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,11 @@ def ode2julia(
470470
1e-8,
471471
help="Delta value for the rush larsen schemes",
472472
),
473+
type_stable: bool = typer.Option(
474+
False,
475+
"--type-stable",
476+
help="Add T to the function signature",
477+
),
473478
# format: CFormat = typer.Option(
474479
# CFormat.clang_format,
475480
# "--format",
@@ -497,6 +502,7 @@ def ode2julia(
497502
verbose=verbose,
498503
stiff_states=stiff_states,
499504
delta=delta,
505+
type_stable=type_stable,
500506
)
501507

502508

src/gotranx/cli/gotran2julia.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def get_code(
2121
missing_values: dict[str, int] | None = None,
2222
delta: float = 1e-8,
2323
stiff_states: list[str] | None = None,
24+
type_stable: bool = False,
2425
) -> str:
2526
"""Generate the Julia code for the ODE
2627
@@ -41,13 +42,17 @@ def get_code(
4142
stiff_states : list[str] | None, optional
4243
Stiff states, by default None. Only applicable for
4344
the hybrid rush larsen scheme
45+
type_stable : bool, optional
46+
Add T to the function signature, by default False
4447
4548
Returns
4649
-------
4750
str
4851
The Julia code
4952
"""
50-
codegen = JuliaCodeGenerator(ode, remove_unused=remove_unused) # , format=Format.none)
53+
codegen = JuliaCodeGenerator(
54+
ode, remove_unused=remove_unused, type_stable=type_stable
55+
) # , format=Format.none)
5156
# formatter = get_formatter(format=format)
5257

5358
if missing_values is not None:
@@ -95,6 +100,7 @@ def main(
95100
missing_values: dict[str, int] | None = None,
96101
delta: float = 1e-8,
97102
stiff_states: list[str] | None = None,
103+
type_stable: bool = False,
98104
) -> None:
99105
loglevel = logging.DEBUG if verbose else logging.INFO
100106
structlog.configure(
@@ -109,6 +115,7 @@ def main(
109115
missing_values=missing_values,
110116
delta=delta,
111117
stiff_states=stiff_states,
118+
type_stable=type_stable,
112119
)
113120
out = fname if outname is None else Path(outname)
114121
out_name = out.with_suffix(suffix=".jl")

src/gotranx/codegen/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Func(typing.NamedTuple):
2525
values_type: str
2626
return_name: str = "values"
2727
num_return_values: int = 0
28+
post_function_signature: str = ""
2829

2930

3031
class RHSArgument(str, Enum):
@@ -332,6 +333,7 @@ def rhs(self, order: RHSArgument | str = RHSArgument.tsp, use_cse=False) -> str:
332333
shape_info="",
333334
values_type=rhs.values_type,
334335
missing_variables=missing_variables,
336+
post_function_signature=rhs.post_function_signature,
335337
)
336338

337339
return self._format(code)
@@ -400,6 +402,7 @@ def monitor_values(self, order: RHSArgument | str = RHSArgument.tsp, use_cse=Fal
400402
shape_info=shape_info,
401403
values_type="numpy.zeros(shape)",
402404
missing_variables=missing_variables,
405+
post_function_signature=rhs.post_function_signature,
403406
)
404407

405408
return self._format(code)
@@ -450,6 +453,7 @@ def missing_values(
450453
shape_info=shape_info,
451454
values_type="numpy.zeros(shape)",
452455
missing_variables=missing_variables,
456+
post_function_signature=rhs.post_function_signature,
453457
)
454458

455459
return self._format(code)
@@ -503,6 +507,7 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) ->
503507
shape_info="",
504508
values_type=rhs.values_type,
505509
missing_variables=missing_variables,
510+
post_function_signature=rhs.post_function_signature,
506511
)
507512
return self._format(code)
508513

src/gotranx/codegen/julia.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@ def bool_to_int(expr: str) -> str:
1616

1717

1818
class GotranJuliaCodePrinter(JuliaCodePrinter):
19-
def __init__(self, *args, **kwargs):
19+
def __init__(self, type_stable: bool = False, *args, **kwargs):
2020
super().__init__(*args, **kwargs)
21+
self._type_stable = type_stable
2122
self._settings["contract"] = False
2223

2324
def _print_Float(self, flt):
24-
return self._print(str(float(flt)))
25+
value = str(float(flt))
26+
if self._type_stable:
27+
return self._print(f"T({value})")
28+
return self._print(value)
2529

2630
def _print_Piecewise(self, expr):
2731
if isinstance(expr.args[0][0], Assignment):
@@ -56,9 +60,9 @@ def _print_Indexed(self, expr):
5660

5761

5862
class JuliaCodeGenerator(CodeGenerator):
59-
def __init__(self, ode: ODE, remove_unused: bool = False) -> None:
63+
def __init__(self, ode: ODE, remove_unused: bool = False, type_stable: bool = False) -> None:
6064
super().__init__(ode, remove_unused=remove_unused)
61-
self._printer = GotranJuliaCodePrinter()
65+
self._printer = GotranJuliaCodePrinter(type_stable=type_stable)
6266
# setattr(self, "_formatter", get_formatter(format=format))
6367

6468
@property
@@ -84,12 +88,23 @@ def _rhs_arguments(
8488
self, order: RHSArgument | str = RHSArgument.stp, const_states: bool = True
8589
) -> Func:
8690
value = RHSArgument.get_value(order)
87-
argument_dict = {
88-
"s": "states",
89-
"t": "t",
90-
"p": "parameters",
91-
}
92-
argument_list = [argument_dict[v] for v in value] + ["values"]
91+
if self._printer._type_stable:
92+
argument_dict = {
93+
"s": "states::AbstractVector{T}",
94+
"t": "t::T",
95+
"p": "parameters::AbstractVector{T}",
96+
}
97+
values = ["values::AbstractVector{T}"]
98+
post_function_signature = " where T"
99+
else:
100+
argument_dict = {
101+
"s": "states",
102+
"t": "t",
103+
"p": "parameters",
104+
}
105+
values = ["values"]
106+
post_function_signature = ""
107+
argument_list = [argument_dict[v] for v in value] + values
93108
states = sympy.IndexedBase("states", shape=(self.ode.num_states,), offset=1)
94109
parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,), offset=1)
95110
values = sympy.IndexedBase("values", shape=(self.ode.num_states,), offset=1)
@@ -100,6 +115,7 @@ def _rhs_arguments(
100115
parameters=parameters,
101116
values=values,
102117
values_type="",
118+
post_function_signature=post_function_signature,
103119
)
104120

105121
def _scheme_arguments(
@@ -108,13 +124,26 @@ def _scheme_arguments(
108124
const_states: bool = True,
109125
) -> Func:
110126
value = SchemeArgument.get_value(order)
111-
argument_dict = {
112-
"s": "states",
113-
"t": "t",
114-
"d": "dt",
115-
"p": "parameters",
116-
}
117-
argument_list = [argument_dict[v] for v in value] + ["values"]
127+
if self._printer._type_stable:
128+
argument_dict = {
129+
"s": "states::AbstractVector{T}",
130+
"t": "t::T",
131+
"d": "dt::T",
132+
"p": "parameters::AbstractVector{T}",
133+
}
134+
values = ["values::AbstractVector{T}"]
135+
post_function_signature = " where {T}"
136+
else:
137+
argument_dict = {
138+
"s": "states",
139+
"t": "t",
140+
"d": "dt",
141+
"p": "parameters",
142+
}
143+
values = ["values"]
144+
post_function_signature = ""
145+
146+
argument_list = [argument_dict[v] for v in value] + values
118147
states = sympy.IndexedBase("states", shape=(self.ode.num_states,), offset=1)
119148
parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,), offset=1)
120149
values = sympy.IndexedBase("values", shape=(self.ode.num_states,), offset=1)
@@ -125,4 +154,5 @@ def _scheme_arguments(
125154
parameters=parameters,
126155
values=values,
127156
values_type="",
157+
post_function_signature=post_function_signature,
128158
)

src/gotranx/templates/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def method(
151151
shape_info: str,
152152
values_type: str,
153153
missing_variables: str,
154+
post_function_signature: str,
154155
) -> str:
155156
"""The method function is a function that generates a method
156157
for the model.
@@ -183,6 +184,8 @@ def method(
183184
The type of the values
184185
missing_variables : str
185186
The code for handling missing variables
187+
post_function_signature : str
188+
The code going after the function signature (e.g. 'where T' or '-> None')
186189
187190
Returns
188191
-------

src/gotranx/templates/julia.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,16 @@ def method(
4545
values: str,
4646
return_name: None = None,
4747
num_return_values: int = 0,
48+
post_function_signature: str = "",
4849
**kwargs,
4950
):
5051
indent_states = indent(states, " ")
5152
indent_parameters = indent(parameters, " ")
5253
indent_values = indent(values, " ")
54+
5355
return dedent(
5456
f"""
55-
function {name}({args})
57+
function {name}({args}){post_function_signature}
5658
5759
# Assign states
5860
{indent_states}

tests/test_julia_codegen.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,60 @@ def test_julia_monitored(codegen: JuliaCodeGenerator):
191191
"\nend"
192192
"\n"
193193
)
194+
195+
196+
def test_consistent_floats(parser, trans):
197+
expr = """
198+
\nstates(x=0)
199+
\ndx_dt = Conditional(Ge(x, 31.4978), 1.0, 1.0763*exp(-1.007*exp(-0.0829*x)))
200+
"""
201+
202+
tree = parser.parse(expr)
203+
result = trans.transform(tree)
204+
ode = make_ode(*result, name="name")
205+
codegen = JuliaCodeGenerator(ode)
206+
rhs = codegen.rhs()
207+
208+
assert rhs == (
209+
"\nfunction rhs(t, states, parameters, values)"
210+
"\n"
211+
"\n # Assign states"
212+
"\n x = states[1]"
213+
"\n"
214+
"\n # Assign parameters"
215+
"\n"
216+
"\n"
217+
"\n # Assign expressions"
218+
"\n dx_dt = ((x >= 31.4978) ? (1.0) : (1.0763 * exp((-1.007) * exp((-0.0829) * x))))"
219+
"\n values[1] = dx_dt"
220+
"\nend"
221+
"\n"
222+
)
223+
224+
225+
def test_consistent_floats_with_T(parser, trans):
226+
expr = """
227+
\nstates(x=0)
228+
\ndx_dt = Conditional(Ge(x, 31.4978), 1.0, 1.0763*exp(-1.007*exp(-0.0829*x)))
229+
"""
230+
231+
tree = parser.parse(expr)
232+
result = trans.transform(tree)
233+
ode = make_ode(*result, name="name")
234+
codegen = JuliaCodeGenerator(ode, type_stable=True)
235+
rhs = codegen.rhs()
236+
assert rhs == (
237+
"\nfunction rhs(t::T, states::AbstractVector{T}, parameters::AbstractVector{T}, values::AbstractVector{T}) where T" # noqa: E501
238+
"\n"
239+
"\n # Assign states"
240+
"\n x = states[1]"
241+
"\n"
242+
"\n # Assign parameters"
243+
"\n"
244+
"\n"
245+
"\n # Assign expressions"
246+
"\n dx_dt = ((x >= T(31.4978)) ? (T(1.0)) : (T(1.0763) * exp((-T(1.007)) * exp((-T(0.0829)) * x))))" # noqa: E501
247+
"\n values[1] = dx_dt"
248+
"\nend"
249+
"\n"
250+
)

0 commit comments

Comments
 (0)