Skip to content

Commit cd03d5b

Browse files
committed
Add initial equation support.
Signed-off-by: James Goppert <james.goppert@gmail.com>
1 parent fcb6f89 commit cd03d5b

21 files changed

+4077
-6856
lines changed

cyecca/dsl/ROADMAP.md

Lines changed: 208 additions & 795 deletions
Large diffs are not rendered by default.

cyecca/dsl/__init__.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,28 @@
137137
submodel,
138138
var,
139139
)
140-
from cyecca.dsl.operators import abs, acos, asin, atan, atan2, cos, exp, log, sin, sqrt, tan
140+
from cyecca.dsl.operators import (
141+
abs,
142+
acos,
143+
asin,
144+
atan,
145+
atan2,
146+
ceil,
147+
cos,
148+
cosh,
149+
exp,
150+
floor,
151+
log,
152+
log10,
153+
max,
154+
min,
155+
sign,
156+
sin,
157+
sinh,
158+
sqrt,
159+
tan,
160+
tanh,
161+
)
141162
from cyecca.dsl.simulation import SimulationResult, Simulator
142163
from cyecca.dsl.types import DType, Indices, Shape, Var, VarKind
143164

@@ -193,5 +214,14 @@
193214
"sqrt",
194215
"exp",
195216
"log",
217+
"log10",
196218
"abs",
219+
"sign",
220+
"floor",
221+
"ceil",
222+
"sinh",
223+
"cosh",
224+
"tanh",
225+
"min",
226+
"max",
197227
]

cyecca/dsl/backends/casadi.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,36 @@ def expr_to_casadi(expr: Expr) -> ca.SX:
197197
elif expr.kind == ExprKind.LOG:
198198
return ca.log(expr_to_casadi(expr.children[0]))
199199

200+
elif expr.kind == ExprKind.LOG10:
201+
return ca.log10(expr_to_casadi(expr.children[0]))
202+
200203
elif expr.kind == ExprKind.ABS:
201204
return ca.fabs(expr_to_casadi(expr.children[0]))
202205

206+
elif expr.kind == ExprKind.SIGN:
207+
return ca.sign(expr_to_casadi(expr.children[0]))
208+
209+
elif expr.kind == ExprKind.FLOOR:
210+
return ca.floor(expr_to_casadi(expr.children[0]))
211+
212+
elif expr.kind == ExprKind.CEIL:
213+
return ca.ceil(expr_to_casadi(expr.children[0]))
214+
215+
elif expr.kind == ExprKind.SINH:
216+
return ca.sinh(expr_to_casadi(expr.children[0]))
217+
218+
elif expr.kind == ExprKind.COSH:
219+
return ca.cosh(expr_to_casadi(expr.children[0]))
220+
221+
elif expr.kind == ExprKind.TANH:
222+
return ca.tanh(expr_to_casadi(expr.children[0]))
223+
224+
elif expr.kind == ExprKind.MIN:
225+
return ca.fmin(expr_to_casadi(expr.children[0]), expr_to_casadi(expr.children[1]))
226+
227+
elif expr.kind == ExprKind.MAX:
228+
return ca.fmax(expr_to_casadi(expr.children[0]), expr_to_casadi(expr.children[1]))
229+
203230
# Relational operators
204231
elif expr.kind == ExprKind.LT:
205232
return expr_to_casadi(expr.children[0]) < expr_to_casadi(expr.children[1])
@@ -451,9 +478,36 @@ def expr_to_casadi_mx(expr: Expr) -> ca.MX:
451478
elif expr.kind == ExprKind.LOG:
452479
return ca.log(expr_to_casadi_mx(expr.children[0]))
453480

481+
elif expr.kind == ExprKind.LOG10:
482+
return ca.log10(expr_to_casadi_mx(expr.children[0]))
483+
454484
elif expr.kind == ExprKind.ABS:
455485
return ca.fabs(expr_to_casadi_mx(expr.children[0]))
456486

487+
elif expr.kind == ExprKind.SIGN:
488+
return ca.sign(expr_to_casadi_mx(expr.children[0]))
489+
490+
elif expr.kind == ExprKind.FLOOR:
491+
return ca.floor(expr_to_casadi_mx(expr.children[0]))
492+
493+
elif expr.kind == ExprKind.CEIL:
494+
return ca.ceil(expr_to_casadi_mx(expr.children[0]))
495+
496+
elif expr.kind == ExprKind.SINH:
497+
return ca.sinh(expr_to_casadi_mx(expr.children[0]))
498+
499+
elif expr.kind == ExprKind.COSH:
500+
return ca.cosh(expr_to_casadi_mx(expr.children[0]))
501+
502+
elif expr.kind == ExprKind.TANH:
503+
return ca.tanh(expr_to_casadi_mx(expr.children[0]))
504+
505+
elif expr.kind == ExprKind.MIN:
506+
return ca.fmin(expr_to_casadi_mx(expr.children[0]), expr_to_casadi_mx(expr.children[1]))
507+
508+
elif expr.kind == ExprKind.MAX:
509+
return ca.fmax(expr_to_casadi_mx(expr.children[0]), expr_to_casadi_mx(expr.children[1]))
510+
457511
# Relational operators
458512
elif expr.kind == ExprKind.LT:
459513
return expr_to_casadi_mx(expr.children[0]) < expr_to_casadi_mx(expr.children[1])

cyecca/dsl/model.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,16 @@ class ExprKind(Enum):
136136
SQRT = auto()
137137
EXP = auto()
138138
LOG = auto()
139+
LOG10 = auto() # Base-10 logarithm
139140
ABS = auto()
141+
SIGN = auto() # Sign function (-1, 0, or 1)
142+
FLOOR = auto() # Floor function
143+
CEIL = auto() # Ceiling function
144+
SINH = auto() # Hyperbolic sine
145+
COSH = auto() # Hyperbolic cosine
146+
TANH = auto() # Hyperbolic tangent
147+
MIN = auto() # Minimum of two values
148+
MAX = auto() # Maximum of two values
140149

141150

142151
@dataclass(frozen=True)
@@ -197,11 +206,22 @@ def __repr__(self) -> str:
197206
ExprKind.SQRT,
198207
ExprKind.EXP,
199208
ExprKind.LOG,
209+
ExprKind.LOG10,
200210
ExprKind.ABS,
211+
ExprKind.SIGN,
212+
ExprKind.FLOOR,
213+
ExprKind.CEIL,
214+
ExprKind.SINH,
215+
ExprKind.COSH,
216+
ExprKind.TANH,
201217
):
202218
return f"{self.kind.name.lower()}({self.children[0]})"
203219
elif self.kind == ExprKind.ATAN2:
204220
return f"atan2({self.children[0]}, {self.children[1]})"
221+
elif self.kind == ExprKind.MIN:
222+
return f"min({self.children[0]}, {self.children[1]})"
223+
elif self.kind == ExprKind.MAX:
224+
return f"max({self.children[0]}, {self.children[1]})"
205225
elif self.kind == ExprKind.PRE:
206226
return f"pre({self.name})"
207227
elif self.kind == ExprKind.EDGE:
@@ -1683,6 +1703,46 @@ def algorithm(m) -> Generator[Assignment, None, None]:
16831703
return
16841704
yield # Make this a generator
16851705

1706+
def initial_equations(self) -> Generator[Equation, None, None]:
1707+
"""
1708+
Override this method to define initial equations.
1709+
1710+
Initial equations (Modelica: `initial equation` section) are used to
1711+
specify initial conditions for simulation. They are solved once at
1712+
t=0 to determine initial values of states and algebraic variables.
1713+
1714+
This provides more flexibility than just using `start` values:
1715+
- Can specify relationships between initial values
1716+
- Can use equations rather than just fixed values
1717+
- Can leave some variables to be computed from others
1718+
1719+
Example
1720+
-------
1721+
>>> @model
1722+
... class Pendulum:
1723+
... theta = var()
1724+
... omega = var()
1725+
...
1726+
... def equations(m):
1727+
... yield der(m.theta) == m.omega
1728+
... yield der(m.omega) == -9.81 * sin(m.theta)
1729+
...
1730+
... def initial_equations(m):
1731+
... yield m.theta == 0.5 # Initial angle
1732+
... yield m.omega == 0.0 # Start at rest
1733+
1734+
Notes
1735+
-----
1736+
Modelica Spec: Section 8.6 - Initialization, Initial Equation, and Initial Algorithm
1737+
1738+
In Modelica, initial equations form a system that is solved to find
1739+
consistent initial values. Variables with `fixed=True` use their `start`
1740+
values as fixed constraints. Other variables are determined by the
1741+
initial equation system.
1742+
"""
1743+
return
1744+
yield # Make this a generator
1745+
16861746
def flatten(self, expand_arrays: bool = True) -> "FlatModel":
16871747
"""
16881748
Flatten the model into a backend-agnostic representation.
@@ -1763,6 +1823,34 @@ def flatten(self, expand_arrays: bool = True) -> "FlatModel":
17631823
else:
17641824
raise TypeError(f"Expected Assignment in algorithm(), got {type(assign)}")
17651825

1826+
# Collect initial equations
1827+
raw_initial_equations = self.initial_equations()
1828+
initial_equations_list: List[Equation] = []
1829+
1830+
for eq in raw_initial_equations:
1831+
if isinstance(eq, Equation):
1832+
initial_equations_list.append(eq)
1833+
elif isinstance(eq, ArrayEquation):
1834+
if expand_arrays:
1835+
initial_equations_list.extend(eq.expand())
1836+
# For non-expanded case, we could add array initial equations
1837+
# but for now just expand them
1838+
else:
1839+
initial_equations_list.extend(eq.expand())
1840+
else:
1841+
raise TypeError(f"Expected Equation in initial_equations(), got {type(eq)}")
1842+
1843+
# Collect initial equations from submodels
1844+
for sub_name, sub_instance in self._submodels.items():
1845+
for eq in sub_instance.initial_equations():
1846+
if isinstance(eq, Equation):
1847+
prefixed_eq = eq._prefix_names(sub_name)
1848+
initial_equations_list.append(prefixed_eq)
1849+
elif isinstance(eq, ArrayEquation):
1850+
for scalar_eq in eq.expand():
1851+
prefixed_eq = scalar_eq._prefix_names(sub_name)
1852+
initial_equations_list.append(prefixed_eq)
1853+
17661854
# Find all derivatives (der(x)) used in equations to identify states
17671855
derivatives_used: set[str] = set()
17681856
for eq in equations:
@@ -1991,6 +2079,7 @@ def flatten(self, expand_arrays: bool = True) -> "FlatModel":
19912079
input_defaults=input_defaults,
19922080
discrete_defaults=discrete_defaults,
19932081
param_defaults=param_defaults,
2082+
initial_equations=initial_equations_list,
19942083
algorithm_assignments=algorithm_assignments,
19952084
algorithm_locals=algorithm_locals,
19962085
expand_arrays=expand_arrays,
@@ -2046,6 +2135,7 @@ def model(cls: Type[Any]) -> Type[Any]:
20462135
# non-conformant. Output variables are just vars with output=True flag.
20472136
original_equations = getattr(cls, "equations", None)
20482137
original_algorithm = getattr(cls, "algorithm", None)
2138+
original_initial_equations = getattr(cls, "initial_equations", None)
20492139

20502140
# Deprecation check: warn if user defines output_equations (non-Modelica)
20512141
if hasattr(cls, "output_equations"):
@@ -2077,6 +2167,10 @@ def algorithm(self) -> Generator[Assignment, None, None]:
20772167
if original_algorithm is not None:
20782168
yield from original_algorithm(self)
20792169

2170+
def initial_equations(self) -> Generator[Equation, None, None]:
2171+
if original_initial_equations is not None:
2172+
yield from original_initial_equations(self)
2173+
20802174
# Copy the metadata reference
20812175
ModelClass._dsl_metadata = metadata
20822176

@@ -2396,6 +2490,10 @@ class FlatModel:
23962490
discrete_defaults: Dict[str, Any]
23972491
param_defaults: Dict[str, Any]
23982492

2493+
# Initial equations (Modelica: initial equation section)
2494+
# These are solved once at t=0 to determine initial values
2495+
initial_equations: List[Equation] = field(default_factory=list)
2496+
23992497
# Array derivative equations (when expand_arrays=False)
24002498
# For CasADi MX backend: keeps array structure for efficient matrix operations
24012499
# Key is base variable name (e.g., 'pos'), value is {'shape': (3,), 'rhs': SymbolicVar}

0 commit comments

Comments
 (0)