diff --git a/src/gotranx/codegen/__init__.py b/src/gotranx/codegen/__init__.py index edb72de..70513de 100644 --- a/src/gotranx/codegen/__init__.py +++ b/src/gotranx/codegen/__init__.py @@ -5,7 +5,7 @@ from .c import CCodeGenerator, GotranCCodePrinter from .python import PythonCodeGenerator, GotranPythonCodePrinter -from .base import CodeGenerator, Func, RHSArgument, SchemeArgument +from .base import CodeGenerator, RHSFunc, SchemeFunc, RHSArgument, SchemeArgument from .ode import GotranODECodePrinter, BaseGotranODECodePrinter __all__ = [ @@ -13,7 +13,8 @@ "c", "CCodeGenerator", "CodeGenerator", - "Func", + "RHSFunc", + "SchemeFunc", "RHSArgument", "SchemeArgument", "python", diff --git a/src/gotranx/codegen/base.py b/src/gotranx/codegen/base.py index 06d5c1e..ed48b4b 100644 --- a/src/gotranx/codegen/base.py +++ b/src/gotranx/codegen/base.py @@ -14,14 +14,44 @@ from .. import schemes -class Func(typing.NamedTuple): +class RHSFunc(typing.NamedTuple): arguments: list[str] states: sympy.IndexedBase parameters: sympy.IndexedBase values: sympy.IndexedBase values_type: str + order: RHSArgument return_name: str = "values" num_return_values: int = 0 + func: sympy.Function = sympy.Function("rhs") + t: sympy.Symbol = sympy.Symbol("t") + + def __call__( + self, states: sympy.core.Basic, t: sympy.core.Basic, parameters: sympy.core.Basic + ) -> sympy.core.Expr: + arg_list = [("s", states), ("t", t), ("p", parameters)] + sorted_args = [arg[1] for arg in sorted(arg_list, key=lambda x: self.order.index(x[0]))] + return self.func(*sorted_args) + + +class SchemeFunc(typing.NamedTuple): + arguments: list[str] + states: sympy.IndexedBase + parameters: sympy.IndexedBase + values: sympy.IndexedBase + values_type: str + order: SchemeArgument + return_name: str = "values" + num_return_values: int = 0 + func: sympy.Function = sympy.Function("rhs") + t: sympy.Symbol = sympy.Symbol("t") + + def __call__( + self, states: sympy.core.Basic, t: sympy.core.Basic, parameters: sympy.core.Basic + ) -> sympy.core.Expr: + arg_list = [("states", states), ("t", t), ("parameters", parameters)] + sorted_args = [arg[1] for arg in sorted(arg_list, key=lambda x: self.arguments.index(x[0]))] + return self.func(*sorted_args) class RHSArgument(str, Enum): @@ -98,9 +128,11 @@ class CodeGenerator(abc.ABC): def __init__( self, ode: ODE, + order: RHSArgument | str = RHSArgument.tsp, remove_unused: bool = False, ) -> None: self.ode = ode + self.order = RHSArgument[order] self.remove_unused = remove_unused self._missing_variables = ode.missing_variables @@ -137,7 +169,12 @@ def _format(self, code: str) -> str: return formatted_code + def _print_IndexedBase(self, lhs, rhs): + raise NotImplementedError + def _doprint(self, lhs, rhs, use_variable_prefix: bool = False) -> str: + if isinstance(lhs, sympy.IndexedBase): + return self._print_IndexedBase(lhs, rhs) if use_variable_prefix: return f"{self.variable_prefix}{self.printer.doprint(Assignment(lhs, rhs))}" return self.printer.doprint(Assignment(lhs, rhs)) @@ -272,7 +309,7 @@ def _missing_variables_assignments(self): ) return "\n".join(lst) - def rhs(self, order: RHSArgument | str = RHSArgument.tsp, use_cse=False) -> str: + def rhs(self, use_cse=False) -> str: """Generate code for the right hand side of the ODE Parameters @@ -288,7 +325,7 @@ def rhs(self, order: RHSArgument | str = RHSArgument.tsp, use_cse=False) -> str: The generated code """ - rhs = self._rhs_arguments(order) + rhs = self._rhs_arguments(self.order) states = self._state_assignments(rhs.states, remove_unused=self.remove_unused) parameters = self._parameter_assignments(rhs.parameters) missing_variables = self._missing_variables_assignments() @@ -443,18 +480,31 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) -> kwargs : dict Additional keyword arguments to be passed to the scheme function + Notes + ----- + The scheme function should take the following arguments: + - ode: gotranx.ode.ODE + - dt: sympy.Symbol + - name: str + - printer: printer_func + - remove_unused: bool + + and return a list of equations as strings that can be + formatted into a code snippet. + Returns ------- str The generated code """ - rhs = self._scheme_arguments(order) - states = self._state_assignments(rhs.states, remove_unused=False) - parameters = self._parameter_assignments(rhs.parameters) + scheme = self._scheme_arguments(order) + rhs = self._rhs_arguments(self.order) + states = self._state_assignments(scheme.states, remove_unused=False) + parameters = self._parameter_assignments(scheme.parameters) missing_variables = self._missing_variables_assignments() - arguments = rhs.arguments + arguments = scheme.arguments if self._missing_variables: arguments += ["missing_variables"] @@ -465,6 +515,7 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) -> name=rhs.return_name, printer=self._doprint, remove_unused=self.remove_unused, + rhs=rhs, **kwargs, ) values = "\n".join(eqs) @@ -475,10 +526,10 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) -> states=states, parameters=parameters, values=values, - return_name=rhs.return_name, - num_return_values=rhs.num_return_values, + return_name=scheme.return_name, + num_return_values=scheme.num_return_values, shape_info="", - values_type=rhs.values_type, + values_type=scheme.values_type, missing_variables=missing_variables, ) return self._format(code) @@ -492,7 +543,7 @@ def printer(self) -> CodePrinter: ... def template(self) -> templates.Template: ... @abc.abstractmethod - def _rhs_arguments(self, order: RHSArgument | str) -> Func: ... + def _rhs_arguments(self, order: RHSArgument | str) -> RHSFunc: ... @abc.abstractmethod - def _scheme_arguments(self, order: SchemeArgument | str) -> Func: ... + def _scheme_arguments(self, order: SchemeArgument | str) -> SchemeFunc: ... diff --git a/src/gotranx/codegen/c.py b/src/gotranx/codegen/c.py index d0dd13a..a5b5403 100644 --- a/src/gotranx/codegen/c.py +++ b/src/gotranx/codegen/c.py @@ -5,7 +5,7 @@ from ..ode import ODE from .. import templates -from .base import CodeGenerator, Func, RHSArgument, SchemeArgument +from .base import CodeGenerator, RHSFunc, SchemeFunc, RHSArgument, SchemeArgument def bool_to_int(expr: str) -> str: @@ -15,6 +15,7 @@ def bool_to_int(expr: str) -> str: class GotranCCodePrinter(C99CodePrinter): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._settings["allow_unknown_functions"] = True self._settings["contract"] = False def _print_Float(self, flt): @@ -47,6 +48,9 @@ def _print_Piecewise(self, expr): return value + def _print_IndexedBase(self, expr): + return "Hello" + class CCodeGenerator(CodeGenerator): variable_prefix = "const double " @@ -84,8 +88,10 @@ def imports(self) -> str: ) def _rhs_arguments( - self, order: RHSArgument | str = RHSArgument.stp, const_states: bool = True - ) -> Func: + self, order: RHSArgument | str | None = None, const_states: bool = True + ) -> RHSFunc: + if order is None: + order = self.order value = RHSArgument.get_value(order) states_prefix = "const " if const_states else "" argument_dict = { @@ -98,19 +104,20 @@ def _rhs_arguments( parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,)) values = sympy.IndexedBase("values", shape=(self.ode.num_states,)) - return Func( + return RHSFunc( arguments=argument_list, states=states, parameters=parameters, values=values, values_type="", + order=RHSArgument[order], ) def _scheme_arguments( self, order: SchemeArgument | str = SchemeArgument.stdp, const_states: bool = True, - ) -> Func: + ) -> SchemeFunc: value = SchemeArgument.get_value(order) states_prefix = "const " if const_states else "" argument_dict = { @@ -124,10 +131,11 @@ def _scheme_arguments( parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,)) values = sympy.IndexedBase("values", shape=(self.ode.num_states,)) - return Func( + return SchemeFunc( arguments=argument_list, states=states, parameters=parameters, values=values, values_type="", + order=SchemeArgument[order], ) diff --git a/src/gotranx/codegen/python.py b/src/gotranx/codegen/python.py index 806b0f9..610b895 100644 --- a/src/gotranx/codegen/python.py +++ b/src/gotranx/codegen/python.py @@ -8,11 +8,15 @@ from ..ode import ODE from .. import templates -from .base import CodeGenerator, Func, RHSArgument, SchemeArgument, _print_Piecewise +from .base import CodeGenerator, RHSFunc, SchemeFunc, RHSArgument, SchemeArgument, _print_Piecewise # class GotranPythonCodePrinter(NumPyPrinter): class GotranPythonCodePrinter(PythonCodePrinter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._settings["allow_unknown_functions"] = True + _kf = {k: f"numpy.{v.replace('math.', '')}" for k, v in PythonCodePrinter._kf.items()} _kc = {k: f"numpy.{v.replace('math.', '')}" for k, v in PythonCodePrinter._kc.items()} @@ -96,6 +100,9 @@ def _print_Or(self, expr): return value + # def _print_Function(self, expr): + # return "%s(%s)" % (self._print(expr.func), ", ".join(map(self._print, expr.args))) + class PythonCodeGenerator(CodeGenerator): def __init__(self, ode: ODE, apply_black: bool = True, *args, **kwargs) -> None: @@ -125,8 +132,10 @@ def imports(self) -> str: def _rhs_arguments( self, - order: RHSArgument | str = RHSArgument.stp, - ) -> Func: + order: RHSArgument | str | None = None, + ) -> RHSFunc: + if order is None: + order = self.order value = RHSArgument.get_value(order) argument_dict = { @@ -140,7 +149,7 @@ def _rhs_arguments( parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,)) values = sympy.IndexedBase("values", shape=(self.ode.num_states,)) - return Func( + return RHSFunc( arguments=argument_list, states=states, parameters=parameters, @@ -148,12 +157,13 @@ def _rhs_arguments( return_name="values", num_return_values=self.ode.num_states, values_type="numpy.zeros_like(states, dtype=numpy.float64)", + order=RHSArgument[order], ) def _scheme_arguments( self, order: SchemeArgument | str = SchemeArgument.stdp, - ) -> Func: + ) -> SchemeFunc: value = SchemeArgument.get_value(order) argument_dict = { @@ -168,7 +178,7 @@ def _scheme_arguments( parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,)) values = sympy.IndexedBase("values", shape=(self.ode.num_states,)) - return Func( + return SchemeFunc( arguments=argument_list, states=states, parameters=parameters, @@ -176,4 +186,5 @@ def _scheme_arguments( return_name="values", num_return_values=self.ode.num_states, values_type="numpy.zeros_like(states, dtype=numpy.float64)", + order=SchemeArgument[order], ) diff --git a/src/gotranx/schemes.py b/src/gotranx/schemes.py index 184efa2..07cffaf 100644 --- a/src/gotranx/schemes.py +++ b/src/gotranx/schemes.py @@ -1,6 +1,6 @@ from __future__ import annotations -import typing from types import CodeType +import typing import sympy from structlog import get_logger @@ -10,6 +10,9 @@ from . import sympytools from ._enum import DeprecatedEnum +if typing.TYPE_CHECKING: + from .codegen.base import RHSFunc + logger = get_logger() @@ -40,6 +43,7 @@ def __call__( self, ode: ODE, dt: sympy.Symbol, + rhs: RHSFunc, name: str = "values", printer: printer_func = default_printer, remove_unused: bool = False, @@ -114,6 +118,7 @@ def fraction_numerator_is_nonzero(expr): def explicit_euler( ode: ODE, dt: sympy.Symbol, + rhs: RHSFunc, name: str = "values", printer: printer_func = default_printer, remove_unused: bool = False, @@ -166,6 +171,7 @@ def explicit_euler( def generalized_rush_larsen( ode: ODE, dt: sympy.Symbol, + rhs: RHSFunc, name: str = "values", printer: printer_func = default_printer, remove_unused: bool = False,