Skip to content

Commit e3b18b9

Browse files
committed
Add analyze causality.
Signed-off-by: James Goppert <james.goppert@gmail.com>
1 parent 45d4005 commit e3b18b9

File tree

8 files changed

+533
-115
lines changed

8 files changed

+533
-115
lines changed

cyecca/dsl/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@
150150
# Flat model representation
151151
from cyecca.dsl.flat_model import FlatModel
152152

153+
# Causality analysis
154+
from cyecca.dsl.causality import SolvedEquation, ImplicitBlock, SortedSystem, analyze_causality
155+
153156
# Model instance and alias
154157
from cyecca.dsl.instance import Model, ModelInstance
155158
from cyecca.dsl.operators import (
@@ -235,6 +238,11 @@
235238
"Expr",
236239
"ExprKind",
237240
"Equation",
241+
# Causality analysis
242+
"analyze_causality",
243+
"SortedSystem",
244+
"SolvedEquation",
245+
"ImplicitBlock",
238246
# Simulation
239247
"SimulationResult",
240248
"Simulator",

cyecca/dsl/backends/casadi.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@
3838
import numpy as np
3939
from beartype import beartype
4040

41+
from cyecca.dsl.causality import SortedSystem
4142
from cyecca.dsl.equations import Equation, Reinit, WhenClause
4243
from cyecca.dsl.expr import Expr, ExprKind
4344
from cyecca.dsl.flat_model import FlatModel
4445
from cyecca.dsl.simulation import SimulationResult, Simulator
4546
from cyecca.dsl.variables import SymbolicVar
4647

48+
# Type for model input - either raw FlatModel or analyzed SortedSystem
49+
ModelInput = Union[FlatModel, SortedSystem]
50+
4751

4852
class SymbolicType(Enum):
4953
"""CasADi symbolic type selection."""
@@ -256,9 +260,9 @@ def _create_symbols(self) -> None:
256260
# Time symbol
257261
self.t_sym = self._sym("t")
258262

259-
# Check if model uses implicit DAE form (der() appears in equations RHS)
260-
# Use is_explicit from flattener (no longer detect here)
261-
self.is_implicit_dae = not model.is_explicit
263+
# Detect if model uses implicit DAE form
264+
# Implicit if any equation has der() on RHS or non-pure der() on LHS
265+
self.is_implicit_dae = self._detect_implicit_dae()
262266

263267
# Combined lookup (NOT including outputs - they get substituted)
264268
self.base_syms = {
@@ -269,6 +273,32 @@ def _create_symbols(self) -> None:
269273
**self.discrete_syms,
270274
}
271275

276+
def _detect_implicit_dae(self) -> bool:
277+
"""Detect if the model requires implicit DAE form.
278+
279+
Returns True if any equation:
280+
1. Has der() on RHS (e.g., output == der(x) + der(y))
281+
2. Has der() on LHS but not in pure form (e.g., m * der(v) == g)
282+
"""
283+
from cyecca.dsl.expr import find_derivatives
284+
285+
for eq in self.model.equations:
286+
# Skip output equations
287+
if eq.lhs.kind == ExprKind.VARIABLE and eq.lhs.name in self.model.output_equations:
288+
continue
289+
290+
rhs_derivs = find_derivatives(eq.rhs)
291+
if rhs_derivs:
292+
# der() on RHS - implicit
293+
return True
294+
295+
lhs_derivs = find_derivatives(eq.lhs)
296+
if lhs_derivs and not eq.is_derivative:
297+
# der() on LHS but not pure der(x) == rhs form
298+
return True
299+
300+
return False
301+
272302
def _resolve_indexed_variable(self, name: str) -> SymT:
273303
"""
274304
Resolve an indexed variable name like 'pos[0]' or 'R[0,1]'.
@@ -424,25 +454,25 @@ def expr_to_casadi_when(self, expr: Expr) -> SymT:
424454
def _build_state_derivatives(self) -> List[SymT]:
425455
"""Build the state derivative vector for explicit ODE form.
426456
427-
Only valid when model.is_explicit=True. Each differential equation
428-
must be in form der(x) == rhs where var_name is set.
457+
Only valid when system is explicit. Each state must have exactly one
458+
equation of form der(x) == rhs.
429459
"""
430460
model = self.model
431461
state_derivs: List[SymT] = []
432462

433463
# Build lookup from var_name to rhs for explicit equations
434464
deriv_rhs_map: Dict[str, Expr] = {}
435-
for eq in model.differential_equations:
465+
for eq in model.equations:
436466
if eq.is_derivative and eq.var_name:
437467
deriv_rhs_map[eq.var_name] = eq.rhs
438468

439469
for name in model.state_names:
440470
shape = self.state_shapes.get(name, ())
441471
size = self._shape_to_size(shape)
442472

443-
# Check for array differential equation (MX backend)
444-
if self.is_mx and name in model.array_differential_equations:
445-
arr_eq = model.array_differential_equations[name]
473+
# Check for array equation (MX backend)
474+
if self.is_mx and name in model.array_equations:
475+
arr_eq = model.array_equations[name]
446476
rhs = arr_eq["rhs"]
447477
# RHS is a SymbolicVar - get its symbol
448478
if rhs.base_name in self.base_syms:
@@ -503,12 +533,29 @@ def _build_outputs(self) -> List[SymT]:
503533
return y_exprs
504534

505535
def _build_algebraic_residuals(self) -> List[SymT]:
506-
"""Build algebraic equation residuals (0 = lhs - rhs)."""
536+
"""Build algebraic equation residuals (0 = lhs - rhs).
537+
538+
Algebraic equations are those that don't contain der() on LHS
539+
and are not output equations.
540+
"""
541+
from cyecca.dsl.expr import find_derivatives
542+
507543
residuals: List[SymT] = []
508-
for eq in self.model.algebraic_equations:
544+
for eq in self.model.equations:
545+
# Skip output equations
546+
if eq.lhs.kind == ExprKind.VARIABLE and eq.lhs.name in self.model.output_equations:
547+
continue
548+
549+
# Skip differential equations (those with der() on LHS)
550+
lhs_derivs = find_derivatives(eq.lhs)
551+
if lhs_derivs:
552+
continue
553+
554+
# This is an algebraic equation
509555
lhs = self.expr_to_casadi(eq.lhs)
510556
rhs = self.expr_to_casadi(eq.rhs)
511557
residuals.append(lhs - rhs)
558+
512559
return residuals
513560

514561
def _build_differential_residuals(self) -> List[SymT]:

0 commit comments

Comments
 (0)