Skip to content

Commit 2bde260

Browse files
authored
Merge pull request #221 from jtgreen/jtg/julia-mtk
Jtg/julia mtk
2 parents 8bdd38c + 5d1339a commit 2bde260

File tree

7 files changed

+371
-4
lines changed

7 files changed

+371
-4
lines changed

src/gotranx/cli/__init__.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..schemes import Scheme, get_scheme
99
from ..codegen import PythonFormat, CFormat
1010
from ..codegen.base import Shape
11-
from . import gotran2c, gotran2py, gotran2julia, gotran2md
11+
from . import gotran2c, gotran2py, gotran2julia, gotran2md, gotran2mtk
1212
from . import utils
1313

1414
app = typer.Typer()
@@ -515,6 +515,69 @@ def ode2julia(
515515
)
516516

517517

518+
@app.command()
519+
def ode2mtk(
520+
fname: typing.Optional[Path] = typer.Argument(
521+
None,
522+
exists=True,
523+
file_okay=True,
524+
dir_okay=False,
525+
writable=False,
526+
readable=True,
527+
resolve_path=True,
528+
),
529+
outname: typing.Optional[str] = typer.Option(
530+
None,
531+
"-o",
532+
"--outname",
533+
help="Output name",
534+
),
535+
remove_unused: bool = typer.Option(
536+
False,
537+
"--remove-unused",
538+
help="Remove unused variables",
539+
),
540+
version: bool = typer.Option(
541+
None,
542+
"--version",
543+
callback=version_callback,
544+
is_eager=True,
545+
help="Show version",
546+
),
547+
license: bool = typer.Option(
548+
None,
549+
"--license",
550+
callback=license_callback,
551+
is_eager=True,
552+
help="Show license",
553+
),
554+
config: typing.Optional[Path] = typer.Option(
555+
None,
556+
"-c",
557+
"--config",
558+
help="Read configuration options from a configuration file",
559+
),
560+
verbose: bool = typer.Option(
561+
False,
562+
"--verbose",
563+
"-v",
564+
help="Verbose output",
565+
),
566+
):
567+
if fname is None:
568+
return typer.echo("No file specified")
569+
570+
config_data = utils.read_config(config)
571+
verbose = config_data.get("verbose", verbose)
572+
remove_unused = config_data.get("remove_unused", remove_unused)
573+
gotran2mtk.main(
574+
fname=fname,
575+
outname=outname,
576+
remove_unused=remove_unused,
577+
verbose=verbose,
578+
)
579+
580+
518581
@app.command()
519582
def list_schemes():
520583
from rich.console import Console

src/gotranx/cli/gotran2mtk.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
import logging
5+
import structlog
6+
7+
from ..codegen.mtk import MTKCodeGenerator
8+
from ..load import load_ode
9+
from ..ode import ODE
10+
11+
logger = structlog.get_logger()
12+
13+
14+
def get_code(
15+
ode: ODE,
16+
remove_unused: bool = False,
17+
) -> str:
18+
"""Generate ModelingToolkit.jl code for the ODE."""
19+
codegen = MTKCodeGenerator(ode, remove_unused=remove_unused)
20+
return codegen.generate()
21+
22+
23+
def main(
24+
fname: Path,
25+
outname: str | None = None,
26+
remove_unused: bool = False,
27+
verbose: bool = False,
28+
) -> None:
29+
loglevel = logging.DEBUG if verbose else logging.INFO
30+
structlog.configure(
31+
wrapper_class=structlog.make_filtering_bound_logger(loglevel),
32+
)
33+
ode = load_ode(fname)
34+
code = get_code(
35+
ode,
36+
remove_unused=remove_unused,
37+
)
38+
out = fname if outname is None else Path(outname)
39+
out_name = out.with_suffix(suffix=".jl")
40+
out_name.write_text(code)
41+
logger.info(f"Wrote {out_name}")

src/gotranx/codegen/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .base import CodeGenerator, Func, RHSArgument, SchemeArgument
1010
from .ode import GotranODECodePrinter, BaseGotranODECodePrinter
1111
from .julia import JuliaCodeGenerator, GotranJuliaCodePrinter
12+
from .mtk import MTKCodeGenerator
1213

1314
__all__ = [
1415
"base",
@@ -30,6 +31,5 @@
3031
"JuliaCodeGenerator",
3132
"GotranJuliaCodePrinter",
3233
"JaxCodeGenerator",
33-
"JuliaCodeGenerator",
34-
"GotranJuliaCodePrinter",
34+
"MTKCodeGenerator",
3535
]

src/gotranx/codegen/mtk.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import annotations
2+
3+
from typing import Iterable
4+
5+
import sympy
6+
import structlog
7+
8+
from .base import CodeGenerator, Func, RHSArgument, SchemeArgument
9+
from .julia import GotranJuliaCodePrinter
10+
from .. import templates
11+
from ..ode import ODE
12+
from .. import atoms
13+
14+
logger = structlog.get_logger()
15+
16+
17+
class MTKCodeGenerator(CodeGenerator):
18+
def __init__(self, ode: ODE, remove_unused: bool = False) -> None:
19+
super().__init__(ode, remove_unused=remove_unused)
20+
self._printer = GotranJuliaCodePrinter()
21+
22+
@property
23+
def printer(self):
24+
return self._printer
25+
26+
@property
27+
def template(self):
28+
return templates.mtk
29+
30+
# The following methods satisfy the abstract interface but are not used
31+
def _rhs_arguments(self, order: RHSArgument | str = RHSArgument.tsp, const_states: bool = True):
32+
dummy = sympy.IndexedBase("dummy")
33+
return Func(
34+
arguments=[],
35+
states=dummy,
36+
parameters=dummy,
37+
values=dummy,
38+
values_type="",
39+
)
40+
41+
def _scheme_arguments(
42+
self,
43+
order: SchemeArgument | str = SchemeArgument.stdp,
44+
const_states: bool = True,
45+
):
46+
dummy = sympy.IndexedBase("dummy")
47+
return Func(
48+
arguments=[],
49+
states=dummy,
50+
parameters=dummy,
51+
values=dummy,
52+
values_type="",
53+
)
54+
55+
def _format_expr(self, expr) -> str:
56+
return self.printer.doprint(expr)
57+
58+
def _observed(self, assignments: Iterable[atoms.Assignment]) -> list[str]:
59+
observed = []
60+
for assignment in assignments:
61+
if isinstance(assignment, atoms.Intermediate):
62+
observed.append(f"{assignment.name} ~ {self._format_expr(assignment.expr)}")
63+
return observed
64+
65+
def _equations(self, derivatives: Iterable[atoms.StateDerivative]) -> list[str]:
66+
equations = []
67+
for derivative in derivatives:
68+
lhs = f"D({derivative.state.name})"
69+
rhs = self._format_expr(derivative.expr)
70+
equations.append(f"{lhs} ~ {rhs}")
71+
return equations
72+
73+
def generate(self, remove_unused: bool | None = None) -> str:
74+
"""Generate ModelingToolkit.jl code."""
75+
if remove_unused is None:
76+
remove_unused = self.remove_unused
77+
78+
param_names = [p.name for p in self.ode.parameters if self._condition(p.name)]
79+
missing_names = [n for n in self.ode.missing_variables.keys() if self._condition(n)]
80+
extra_params = [n for n in missing_names if n not in param_names]
81+
82+
state_names = [s.name for s in self.ode.sorted_states() if self._condition(s.name)]
83+
84+
assignments = self.ode.sorted_assignments(remove_unused=remove_unused)
85+
observed = self._observed(assignments)
86+
derivatives = [a for a in assignments if isinstance(a, atoms.StateDerivative)]
87+
equations = self._equations(derivatives)
88+
89+
state_defaults = [(s.name, self._format_expr(s.value)) for s in self.ode.sorted_states()]
90+
param_defaults = [(p.name, self._format_expr(p.value)) for p in self.ode.parameters]
91+
param_defaults.extend((n, "0.0") for n in extra_params)
92+
93+
parts = [
94+
self.template.header(),
95+
self.template.parameters_block(param_names, extra_params),
96+
self.template.states_block(state_names),
97+
self.template.observed_block(observed),
98+
self.template.equations_block(equations),
99+
self.template.defaults_block(state_defaults, param_defaults),
100+
self.template.ode_system(self.ode.name, state_names, param_names + extra_params),
101+
]
102+
103+
return "\n".join(parts)

src/gotranx/templates/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from . import python
66
from . import jax
77
from . import julia
8+
from . import mtk
89
from . import markdown
910

1011

@@ -195,4 +196,4 @@ def method(
195196
"""
196197

197198

198-
__all__ = ["c", "python", "jax", "julia", "markdown", "Template"]
199+
__all__ = ["c", "python", "jax", "julia", "markdown", "mtk", "Template"]

src/gotranx/templates/mtk.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
from textwrap import dedent, indent
4+
5+
6+
def header() -> str:
7+
return dedent(
8+
"""
9+
using ModelingToolkit
10+
""",
11+
)
12+
13+
14+
def parameters_block(params: list[str], extra_params: list[str]) -> str:
15+
names = " ".join(["t"] + params + extra_params)
16+
return dedent(
17+
f"""
18+
@parameters {names}
19+
D = Differential(t)
20+
""",
21+
)
22+
23+
24+
def states_block(states: list[str]) -> str:
25+
names = " ".join([f"{s}(t)" for s in states])
26+
return dedent(
27+
f"""
28+
@variables {names}
29+
""",
30+
)
31+
32+
33+
def observed_block(entries: list[str]) -> str:
34+
if not entries:
35+
return "observed = []"
36+
37+
body = indent(",\n".join(entries), " ")
38+
return dedent(
39+
f"""
40+
observed = [
41+
{body}
42+
]
43+
""",
44+
)
45+
46+
47+
def equations_block(entries: list[str]) -> str:
48+
body = indent(",\n".join(entries), " ")
49+
return dedent(
50+
f"""
51+
eqs = [
52+
{body}
53+
]
54+
""",
55+
)
56+
57+
58+
def defaults_block(
59+
state_defaults: list[tuple[str, str]], param_defaults: list[tuple[str, str]]
60+
) -> str:
61+
pairs = []
62+
for name, value in state_defaults + param_defaults:
63+
pairs.append(f"{name} => {value}")
64+
65+
if not pairs:
66+
return "defaults = Dict()"
67+
68+
body = indent(",\n".join(pairs), " ")
69+
return dedent(
70+
f"""
71+
defaults = Dict(
72+
{body}
73+
)
74+
""",
75+
)
76+
77+
78+
def ode_system(name: str, state_names: list[str], param_names: list[str]) -> str:
79+
states = ", ".join(state_names)
80+
params = ", ".join(param_names)
81+
return dedent(
82+
f"""
83+
@named {name} = ODESystem(
84+
eqs,
85+
t,
86+
[{states}],
87+
[{params}],
88+
observed = observed,
89+
defaults = defaults,
90+
)
91+
""",
92+
)

0 commit comments

Comments
 (0)