Skip to content

Commit 4fb7ebe

Browse files
authored
Merge pull request #169 from finsberg/conditional-type-error
Fix weird corner case in conditional when it is not able to simplify the expression to a boolean
2 parents f6aa8de + eaa28ac commit 4fb7ebe

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/gotranx/codegen/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def print_cond(cond):
8787
else:
8888
return printer._print(cond)
8989

90-
expr = sympy.simplify(expr)
91-
90+
try:
91+
expr = sympy.simplify(expr)
92+
except TypeError:
93+
logger.debug(f"Could not simplify expression {expr}")
9294
exprs = [printer._print(arg.expr) for arg in expr.args]
9395
conds = [print_cond(arg.cond) for arg in expr.args]
9496

tests/test_python_codegen.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,3 +824,33 @@ def test_python_jax_codegen_initial_parameter_values(codegen_jax: JaxCodeGenerat
824824
"\n return parameters"
825825
"\n"
826826
)
827+
828+
829+
def test_conditional_times_float(parser, trans):
830+
expr = """ \
831+
\nstates(x=0)
832+
\ndx_dt = (-1.0 - x) * Conditional(Lt(x, -1.0), 1.0,0.0) \
833+
"""
834+
835+
tree = parser.parse(expr)
836+
result = trans.transform(tree)
837+
ode = make_ode(*result, name="name")
838+
codegen = PythonCodeGenerator(ode)
839+
rhs = codegen.rhs()
840+
assert rhs == (
841+
"def rhs(t, states, parameters):"
842+
"\n"
843+
"\n # Assign states"
844+
"\n x = states[0]"
845+
"\n"
846+
"\n # Assign parameters"
847+
"\n"
848+
"\n # Assign expressions"
849+
"\n"
850+
"\n values = numpy.zeros_like(states, dtype=numpy.float64)"
851+
"\n dx_dt = (-x - 1.0) * numpy.where((x < -1.0), 1.0, 0.0)"
852+
"\n values[0] = dx_dt"
853+
"\n"
854+
"\n return values"
855+
"\n"
856+
)

0 commit comments

Comments
 (0)