Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/gotranx/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

from .c import CCodeGenerator, GotranCCodePrinter
from .python import PythonCodeGenerator, GotranPythonCodePrinter
from .base import CodeGenerator, Func, RHSArgument, SchemeArgument
from .base import CodeGenerator, RHSFunc, SchemeFunc, RHSArgument, SchemeArgument
from .ode import GotranODECodePrinter, BaseGotranODECodePrinter

__all__ = [
"base",
"c",
"CCodeGenerator",
"CodeGenerator",
"Func",
"RHSFunc",
"SchemeFunc",
"RHSArgument",
"SchemeArgument",
"python",
Expand Down
75 changes: 63 additions & 12 deletions src/gotranx/codegen/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,44 @@
from .. import schemes


class Func(typing.NamedTuple):
class RHSFunc(typing.NamedTuple):
arguments: list[str]
states: sympy.IndexedBase
parameters: sympy.IndexedBase
values: sympy.IndexedBase
values_type: str
order: RHSArgument
return_name: str = "values"
num_return_values: int = 0
func: sympy.Function = sympy.Function("rhs")
t: sympy.Symbol = sympy.Symbol("t")

def __call__(
self, states: sympy.core.Basic, t: sympy.core.Basic, parameters: sympy.core.Basic
) -> sympy.core.Expr:
arg_list = [("s", states), ("t", t), ("p", parameters)]
sorted_args = [arg[1] for arg in sorted(arg_list, key=lambda x: self.order.index(x[0]))]
return self.func(*sorted_args)


class SchemeFunc(typing.NamedTuple):
arguments: list[str]
states: sympy.IndexedBase
parameters: sympy.IndexedBase
values: sympy.IndexedBase
values_type: str
order: SchemeArgument
return_name: str = "values"
num_return_values: int = 0
func: sympy.Function = sympy.Function("rhs")
t: sympy.Symbol = sympy.Symbol("t")

def __call__(
self, states: sympy.core.Basic, t: sympy.core.Basic, parameters: sympy.core.Basic
) -> sympy.core.Expr:
arg_list = [("states", states), ("t", t), ("parameters", parameters)]
sorted_args = [arg[1] for arg in sorted(arg_list, key=lambda x: self.arguments.index(x[0]))]
return self.func(*sorted_args)


class RHSArgument(str, Enum):
Expand Down Expand Up @@ -98,9 +128,11 @@ class CodeGenerator(abc.ABC):
def __init__(
self,
ode: ODE,
order: RHSArgument | str = RHSArgument.tsp,
remove_unused: bool = False,
) -> None:
self.ode = ode
self.order = RHSArgument[order]
self.remove_unused = remove_unused
self._missing_variables = ode.missing_variables

Expand Down Expand Up @@ -137,7 +169,12 @@ def _format(self, code: str) -> str:

return formatted_code

def _print_IndexedBase(self, lhs, rhs):
raise NotImplementedError

def _doprint(self, lhs, rhs, use_variable_prefix: bool = False) -> str:
if isinstance(lhs, sympy.IndexedBase):
return self._print_IndexedBase(lhs, rhs)
if use_variable_prefix:
return f"{self.variable_prefix}{self.printer.doprint(Assignment(lhs, rhs))}"
return self.printer.doprint(Assignment(lhs, rhs))
Expand Down Expand Up @@ -272,7 +309,7 @@ def _missing_variables_assignments(self):
)
return "\n".join(lst)

def rhs(self, order: RHSArgument | str = RHSArgument.tsp, use_cse=False) -> str:
def rhs(self, use_cse=False) -> str:
"""Generate code for the right hand side of the ODE

Parameters
Expand All @@ -288,7 +325,7 @@ def rhs(self, order: RHSArgument | str = RHSArgument.tsp, use_cse=False) -> str:
The generated code
"""

rhs = self._rhs_arguments(order)
rhs = self._rhs_arguments(self.order)
states = self._state_assignments(rhs.states, remove_unused=self.remove_unused)
parameters = self._parameter_assignments(rhs.parameters)
missing_variables = self._missing_variables_assignments()
Expand Down Expand Up @@ -443,18 +480,31 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) ->
kwargs : dict
Additional keyword arguments to be passed to the scheme function

Notes
-----
The scheme function should take the following arguments:
- ode: gotranx.ode.ODE
- dt: sympy.Symbol
- name: str
- printer: printer_func
- remove_unused: bool

and return a list of equations as strings that can be
formatted into a code snippet.

Returns
-------
str
The generated code
"""

rhs = self._scheme_arguments(order)
states = self._state_assignments(rhs.states, remove_unused=False)
parameters = self._parameter_assignments(rhs.parameters)
scheme = self._scheme_arguments(order)
rhs = self._rhs_arguments(self.order)
states = self._state_assignments(scheme.states, remove_unused=False)
parameters = self._parameter_assignments(scheme.parameters)
missing_variables = self._missing_variables_assignments()

arguments = rhs.arguments
arguments = scheme.arguments
if self._missing_variables:
arguments += ["missing_variables"]

Expand All @@ -465,6 +515,7 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) ->
name=rhs.return_name,
printer=self._doprint,
remove_unused=self.remove_unused,
rhs=rhs,
**kwargs,
)
values = "\n".join(eqs)
Expand All @@ -475,10 +526,10 @@ def scheme(self, f: schemes.scheme_func, order=SchemeArgument.stdp, **kwargs) ->
states=states,
parameters=parameters,
values=values,
return_name=rhs.return_name,
num_return_values=rhs.num_return_values,
return_name=scheme.return_name,
num_return_values=scheme.num_return_values,
shape_info="",
values_type=rhs.values_type,
values_type=scheme.values_type,
missing_variables=missing_variables,
)
return self._format(code)
Expand All @@ -492,7 +543,7 @@ def printer(self) -> CodePrinter: ...
def template(self) -> templates.Template: ...

@abc.abstractmethod
def _rhs_arguments(self, order: RHSArgument | str) -> Func: ...
def _rhs_arguments(self, order: RHSArgument | str) -> RHSFunc: ...

@abc.abstractmethod
def _scheme_arguments(self, order: SchemeArgument | str) -> Func: ...
def _scheme_arguments(self, order: SchemeArgument | str) -> SchemeFunc: ...
20 changes: 14 additions & 6 deletions src/gotranx/codegen/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..ode import ODE
from .. import templates
from .base import CodeGenerator, Func, RHSArgument, SchemeArgument
from .base import CodeGenerator, RHSFunc, SchemeFunc, RHSArgument, SchemeArgument


def bool_to_int(expr: str) -> str:
Expand All @@ -15,6 +15,7 @@ def bool_to_int(expr: str) -> str:
class GotranCCodePrinter(C99CodePrinter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._settings["allow_unknown_functions"] = True
self._settings["contract"] = False

def _print_Float(self, flt):
Expand Down Expand Up @@ -47,6 +48,9 @@ def _print_Piecewise(self, expr):

return value

def _print_IndexedBase(self, expr):
return "Hello"


class CCodeGenerator(CodeGenerator):
variable_prefix = "const double "
Expand Down Expand Up @@ -84,8 +88,10 @@ def imports(self) -> str:
)

def _rhs_arguments(
self, order: RHSArgument | str = RHSArgument.stp, const_states: bool = True
) -> Func:
self, order: RHSArgument | str | None = None, const_states: bool = True
) -> RHSFunc:
if order is None:
order = self.order
value = RHSArgument.get_value(order)
states_prefix = "const " if const_states else ""
argument_dict = {
Expand All @@ -98,19 +104,20 @@ def _rhs_arguments(
parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,))
values = sympy.IndexedBase("values", shape=(self.ode.num_states,))

return Func(
return RHSFunc(
arguments=argument_list,
states=states,
parameters=parameters,
values=values,
values_type="",
order=RHSArgument[order],
)

def _scheme_arguments(
self,
order: SchemeArgument | str = SchemeArgument.stdp,
const_states: bool = True,
) -> Func:
) -> SchemeFunc:
value = SchemeArgument.get_value(order)
states_prefix = "const " if const_states else ""
argument_dict = {
Expand All @@ -124,10 +131,11 @@ def _scheme_arguments(
parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,))
values = sympy.IndexedBase("values", shape=(self.ode.num_states,))

return Func(
return SchemeFunc(
arguments=argument_list,
states=states,
parameters=parameters,
values=values,
values_type="",
order=SchemeArgument[order],
)
23 changes: 17 additions & 6 deletions src/gotranx/codegen/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

from ..ode import ODE
from .. import templates
from .base import CodeGenerator, Func, RHSArgument, SchemeArgument, _print_Piecewise
from .base import CodeGenerator, RHSFunc, SchemeFunc, RHSArgument, SchemeArgument, _print_Piecewise


# class GotranPythonCodePrinter(NumPyPrinter):
class GotranPythonCodePrinter(PythonCodePrinter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._settings["allow_unknown_functions"] = True

_kf = {k: f"numpy.{v.replace('math.', '')}" for k, v in PythonCodePrinter._kf.items()}
_kc = {k: f"numpy.{v.replace('math.', '')}" for k, v in PythonCodePrinter._kc.items()}

Expand Down Expand Up @@ -96,6 +100,9 @@ def _print_Or(self, expr):

return value

# def _print_Function(self, expr):
# return "%s(%s)" % (self._print(expr.func), ", ".join(map(self._print, expr.args)))


class PythonCodeGenerator(CodeGenerator):
def __init__(self, ode: ODE, apply_black: bool = True, *args, **kwargs) -> None:
Expand Down Expand Up @@ -125,8 +132,10 @@ def imports(self) -> str:

def _rhs_arguments(
self,
order: RHSArgument | str = RHSArgument.stp,
) -> Func:
order: RHSArgument | str | None = None,
) -> RHSFunc:
if order is None:
order = self.order
value = RHSArgument.get_value(order)

argument_dict = {
Expand All @@ -140,20 +149,21 @@ def _rhs_arguments(
parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,))
values = sympy.IndexedBase("values", shape=(self.ode.num_states,))

return Func(
return RHSFunc(
arguments=argument_list,
states=states,
parameters=parameters,
values=values,
return_name="values",
num_return_values=self.ode.num_states,
values_type="numpy.zeros_like(states, dtype=numpy.float64)",
order=RHSArgument[order],
)

def _scheme_arguments(
self,
order: SchemeArgument | str = SchemeArgument.stdp,
) -> Func:
) -> SchemeFunc:
value = SchemeArgument.get_value(order)

argument_dict = {
Expand All @@ -168,12 +178,13 @@ def _scheme_arguments(
parameters = sympy.IndexedBase("parameters", shape=(self.ode.num_parameters,))
values = sympy.IndexedBase("values", shape=(self.ode.num_states,))

return Func(
return SchemeFunc(
arguments=argument_list,
states=states,
parameters=parameters,
values=values,
return_name="values",
num_return_values=self.ode.num_states,
values_type="numpy.zeros_like(states, dtype=numpy.float64)",
order=SchemeArgument[order],
)
8 changes: 7 additions & 1 deletion src/gotranx/schemes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import typing
from types import CodeType
import typing

import sympy
from structlog import get_logger
Expand All @@ -10,6 +10,9 @@
from . import sympytools
from ._enum import DeprecatedEnum

if typing.TYPE_CHECKING:
from .codegen.base import RHSFunc

logger = get_logger()


Expand Down Expand Up @@ -40,6 +43,7 @@ def __call__(
self,
ode: ODE,
dt: sympy.Symbol,
rhs: RHSFunc,
name: str = "values",
printer: printer_func = default_printer,
remove_unused: bool = False,
Expand Down Expand Up @@ -114,6 +118,7 @@ def fraction_numerator_is_nonzero(expr):
def explicit_euler(
ode: ODE,
dt: sympy.Symbol,
rhs: RHSFunc,
name: str = "values",
printer: printer_func = default_printer,
remove_unused: bool = False,
Expand Down Expand Up @@ -166,6 +171,7 @@ def explicit_euler(
def generalized_rush_larsen(
ode: ODE,
dt: sympy.Symbol,
rhs: RHSFunc,
name: str = "values",
printer: printer_func = default_printer,
remove_unused: bool = False,
Expand Down