3838import numpy as np
3939from beartype import beartype
4040
41+ from cyecca .dsl .causality import SortedSystem
4142from cyecca .dsl .equations import Equation , Reinit , WhenClause
4243from cyecca .dsl .expr import Expr , ExprKind
4344from cyecca .dsl .flat_model import FlatModel
4445from cyecca .dsl .simulation import SimulationResult , Simulator
4546from cyecca .dsl .variables import SymbolicVar
4647
48+ # Type for model input - either raw FlatModel or analyzed SortedSystem
49+ ModelInput = Union [FlatModel , SortedSystem ]
50+
4751
4852class SymbolicType (Enum ):
4953 """CasADi symbolic type selection."""
@@ -256,9 +260,9 @@ def _create_symbols(self) -> None:
256260 # Time symbol
257261 self .t_sym = self ._sym ("t" )
258262
259- # Check if model uses implicit DAE form (der() appears in equations RHS)
260- # Use is_explicit from flattener (no longer detect here)
261- self .is_implicit_dae = not model . is_explicit
263+ # Detect if model uses implicit DAE form
264+ # Implicit if any equation has der() on RHS or non-pure der() on LHS
265+ self .is_implicit_dae = self . _detect_implicit_dae ()
262266
263267 # Combined lookup (NOT including outputs - they get substituted)
264268 self .base_syms = {
@@ -269,6 +273,32 @@ def _create_symbols(self) -> None:
269273 ** self .discrete_syms ,
270274 }
271275
276+ def _detect_implicit_dae (self ) -> bool :
277+ """Detect if the model requires implicit DAE form.
278+
279+ Returns True if any equation:
280+ 1. Has der() on RHS (e.g., output == der(x) + der(y))
281+ 2. Has der() on LHS but not in pure form (e.g., m * der(v) == g)
282+ """
283+ from cyecca .dsl .expr import find_derivatives
284+
285+ for eq in self .model .equations :
286+ # Skip output equations
287+ if eq .lhs .kind == ExprKind .VARIABLE and eq .lhs .name in self .model .output_equations :
288+ continue
289+
290+ rhs_derivs = find_derivatives (eq .rhs )
291+ if rhs_derivs :
292+ # der() on RHS - implicit
293+ return True
294+
295+ lhs_derivs = find_derivatives (eq .lhs )
296+ if lhs_derivs and not eq .is_derivative :
297+ # der() on LHS but not pure der(x) == rhs form
298+ return True
299+
300+ return False
301+
272302 def _resolve_indexed_variable (self , name : str ) -> SymT :
273303 """
274304 Resolve an indexed variable name like 'pos[0]' or 'R[0,1]'.
@@ -424,25 +454,25 @@ def expr_to_casadi_when(self, expr: Expr) -> SymT:
424454 def _build_state_derivatives (self ) -> List [SymT ]:
425455 """Build the state derivative vector for explicit ODE form.
426456
427- Only valid when model.is_explicit=True. Each differential equation
428- must be in form der(x) == rhs where var_name is set .
457+ Only valid when system is explicit. Each state must have exactly one
458+ equation of form der(x) == rhs.
429459 """
430460 model = self .model
431461 state_derivs : List [SymT ] = []
432462
433463 # Build lookup from var_name to rhs for explicit equations
434464 deriv_rhs_map : Dict [str , Expr ] = {}
435- for eq in model .differential_equations :
465+ for eq in model .equations :
436466 if eq .is_derivative and eq .var_name :
437467 deriv_rhs_map [eq .var_name ] = eq .rhs
438468
439469 for name in model .state_names :
440470 shape = self .state_shapes .get (name , ())
441471 size = self ._shape_to_size (shape )
442472
443- # Check for array differential equation (MX backend)
444- if self .is_mx and name in model .array_differential_equations :
445- arr_eq = model .array_differential_equations [name ]
473+ # Check for array equation (MX backend)
474+ if self .is_mx and name in model .array_equations :
475+ arr_eq = model .array_equations [name ]
446476 rhs = arr_eq ["rhs" ]
447477 # RHS is a SymbolicVar - get its symbol
448478 if rhs .base_name in self .base_syms :
@@ -503,12 +533,29 @@ def _build_outputs(self) -> List[SymT]:
503533 return y_exprs
504534
505535 def _build_algebraic_residuals (self ) -> List [SymT ]:
506- """Build algebraic equation residuals (0 = lhs - rhs)."""
536+ """Build algebraic equation residuals (0 = lhs - rhs).
537+
538+ Algebraic equations are those that don't contain der() on LHS
539+ and are not output equations.
540+ """
541+ from cyecca .dsl .expr import find_derivatives
542+
507543 residuals : List [SymT ] = []
508- for eq in self .model .algebraic_equations :
544+ for eq in self .model .equations :
545+ # Skip output equations
546+ if eq .lhs .kind == ExprKind .VARIABLE and eq .lhs .name in self .model .output_equations :
547+ continue
548+
549+ # Skip differential equations (those with der() on LHS)
550+ lhs_derivs = find_derivatives (eq .lhs )
551+ if lhs_derivs :
552+ continue
553+
554+ # This is an algebraic equation
509555 lhs = self .expr_to_casadi (eq .lhs )
510556 rhs = self .expr_to_casadi (eq .rhs )
511557 residuals .append (lhs - rhs )
558+
512559 return residuals
513560
514561 def _build_differential_residuals (self ) -> List [SymT ]:
0 commit comments