Skip to content

Commit 58665b2

Browse files
committed
Update for autocomplete.
Signed-off-by: James Goppert <james.goppert@gmail.com>
1 parent d98834c commit 58665b2

File tree

6 files changed

+110
-62
lines changed

6 files changed

+110
-62
lines changed

cyecca/dsl/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170
from cyecca.dsl.equations import Assignment, Equation, IfEquation, IfEquationBranch, Reinit, WhenClause
171171

172172
# Expression tree
173-
from cyecca.dsl.expr import Expr, ExprKind
173+
from cyecca.dsl.expr import Expr, ExprKind, ExprLike
174174

175175
# Flat model representation
176176
from cyecca.dsl.flat_model import FlatModel
@@ -221,7 +221,7 @@
221221
terminal,
222222
)
223223
from cyecca.dsl.simulation import SimulationResult, Simulator
224-
from cyecca.dsl.types import DType, Indices, Shape, Var, VarKind
224+
from cyecca.dsl.types import DType, Indices, NumericValue, Shape, Var, VarKind
225225

226226
# Variables
227227
from cyecca.dsl.variables import ArrayDerivativeExpr, DerivativeExpr, SymbolicVar, TimeVar
@@ -246,6 +246,7 @@
246246
"DType",
247247
"Shape",
248248
"Indices",
249+
"NumericValue",
249250
"submodel",
250251
# Free functions (continuous)
251252
"der",
@@ -291,6 +292,7 @@
291292
"FlatModel",
292293
"Expr",
293294
"ExprKind",
295+
"ExprLike",
294296
"Equation",
295297
# Causality analysis
296298
"analyze_causality",

cyecca/dsl/context.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@
3131
from __future__ import annotations
3232

3333
import threading
34-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
34+
from typing import TYPE_CHECKING, Callable, List, Optional, Union
3535

3636
from beartype import beartype
3737

3838
from cyecca.dsl.equations import IfEquation, IfEquationBranch, Reinit, WhenClause
39-
from cyecca.dsl.expr import Expr, to_expr
39+
from cyecca.dsl.expr import Expr, ExprLike, to_expr
40+
from cyecca.dsl.instance import SubmodelProxy
4041
from cyecca.dsl.variables import SymbolicVar
4142

4243
if TYPE_CHECKING:
@@ -427,7 +428,7 @@ def to_when_clause(self) -> "WhenClause":
427428

428429

429430
@beartype
430-
def when(condition: Any) -> WhenContext:
431+
def when(condition: ExprLike) -> WhenContext:
431432
"""
432433
Create a when-clause for event handling (Modelica MLS 8.5).
433434
@@ -453,7 +454,7 @@ def _(m):
453454
454455
Parameters
455456
----------
456-
condition : Expr or SymbolicVar
457+
condition : ExprLike
457458
Boolean condition that triggers the when-clause
458459
459460
Returns
@@ -474,7 +475,7 @@ def _(m):
474475

475476

476477
@beartype
477-
def reinit(var: SymbolicVar, expr: Any) -> Optional[Reinit]:
478+
def reinit(var: SymbolicVar, expr: ExprLike) -> Optional[Reinit]:
478479
"""
479480
Reinitialize a continuous-time state variable at an event.
480481
@@ -489,7 +490,7 @@ def reinit(var: SymbolicVar, expr: Any) -> Optional[Reinit]:
489490
----------
490491
var : SymbolicVar
491492
The state variable to reinitialize
492-
expr : Expr or numeric
493+
expr : ExprLike
493494
The new value expression (can use pre() for previous values)
494495
495496
Returns
@@ -529,7 +530,7 @@ def reinit(var: SymbolicVar, expr: Any) -> Optional[Reinit]:
529530

530531

531532
@beartype
532-
def connect(a: Any, b: Any) -> None:
533+
def connect(a: SubmodelProxy, b: SubmodelProxy) -> None:
533534
"""
534535
Connect two connectors, generating appropriate connection equations.
535536
@@ -598,19 +599,12 @@ def _(m):
598599
where positive flow is into the component.
599600
"""
600601
from cyecca.dsl.equations import Equation
601-
from cyecca.dsl.instance import SubmodelProxy
602602

603603
ctx = get_current_equation_context()
604604
if ctx is None:
605605
raise RuntimeError("connect() can only be used inside an @equations block")
606606

607-
# Validate that both are SubmodelProxy (connector instances)
608-
if not isinstance(a, SubmodelProxy):
609-
raise TypeError(f"connect() first argument must be a connector, got {type(a)}")
610-
if not isinstance(b, SubmodelProxy):
611-
raise TypeError(f"connect() second argument must be a connector, got {type(b)}")
612-
613-
# Get the connector metadata
607+
# Get the connector metadata (beartype already validates SubmodelProxy type)
614608
a_instance = a._instance
615609
b_instance = b._instance
616610
a_metadata = a_instance._metadata
@@ -828,7 +822,7 @@ def execute_equations_method(
828822

829823

830824
@beartype
831-
def if_eq(condition: Any) -> IfContext:
825+
def if_eq(condition: ExprLike) -> IfContext:
832826
"""
833827
Create an if-equation for conditional equations (Modelica MLS 8.3.4).
834828
@@ -855,7 +849,7 @@ def _(m):
855849
856850
Parameters
857851
----------
858-
condition : Expr or SymbolicVar
852+
condition : ExprLike
859853
Boolean condition that selects which equations are active
860854
861855
Returns
@@ -877,7 +871,7 @@ def _(m):
877871

878872

879873
@beartype
880-
def elseif_eq(condition: Any) -> IfContext:
874+
def elseif_eq(condition: ExprLike) -> IfContext:
881875
"""
882876
Create an elseif branch for an if-equation.
883877
@@ -894,7 +888,7 @@ def _(m):
894888
895889
Parameters
896890
----------
897-
condition : Expr or SymbolicVar
891+
condition : ExprLike
898892
Boolean condition for this elseif branch
899893
900894
Returns

cyecca/dsl/decorators.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,29 @@
3232
from __future__ import annotations
3333

3434
from dataclasses import dataclass, field
35-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
35+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, TypeVar, Union, cast
36+
37+
# Python 3.11+ has dataclass_transform in typing, earlier versions need typing_extensions
38+
try:
39+
from typing import dataclass_transform
40+
except ImportError:
41+
try:
42+
from typing_extensions import dataclass_transform
43+
except ImportError:
44+
# Fallback: no-op decorator if not available
45+
def dataclass_transform(**kwargs): # type: ignore[misc]
46+
def decorator(cls_or_fn): # type: ignore[no-untyped-def]
47+
return cls_or_fn
48+
49+
return decorator
50+
3651

3752
import numpy as np
3853
from beartype import beartype
3954

55+
# TypeVar for preserving class type through decorators
56+
_T = TypeVar("_T")
57+
4058
from cyecca.dsl.context import (
4159
execute_algorithm_method,
4260
execute_equations_method,
@@ -47,10 +65,15 @@
4765
)
4866
from cyecca.dsl.equations import ArrayEquation, Assignment, Equation, WhenClause
4967
from cyecca.dsl.instance import ModelInstance
50-
from cyecca.dsl.types import DType, Shape, SubmodelField, Var, VarKind
68+
from cyecca.dsl.types import DType, NumericValue, Shape, SubmodelField, Var, VarKind
69+
from cyecca.dsl.variables import SymbolicVar
5170

71+
# For IDE autocomplete: Real/Integer/Boolean/String return SymbolicVar for type checking
72+
# but Var at runtime. This allows m.theta to autocomplete as SymbolicVar.
5273
if TYPE_CHECKING:
53-
pass
74+
_VarReturn = SymbolicVar
75+
else:
76+
_VarReturn = Var
5477

5578

5679
# =============================================================================
@@ -114,7 +137,7 @@ def var(
114137
protected: bool = False,
115138
# Connector prefixes (Modelica MLS Ch. 9)
116139
flow: bool = False,
117-
) -> Var:
140+
) -> _VarReturn:
118141
"""
119142
Declare a variable in a Cyecca model.
120143
@@ -231,7 +254,7 @@ def Real(
231254
protected: bool = False,
232255
# Connector
233256
flow: bool = False,
234-
) -> Var:
257+
) -> _VarReturn:
235258
"""
236259
Declare a Real (floating-point) variable.
237260
@@ -326,7 +349,7 @@ def Integer(
326349
constant: bool = False,
327350
# Visibility
328351
protected: bool = False,
329-
) -> Var:
352+
) -> _VarReturn:
330353
"""
331354
Declare an Integer variable.
332355
@@ -408,7 +431,7 @@ def Boolean(
408431
constant: bool = False,
409432
# Visibility
410433
protected: bool = False,
411-
) -> Var:
434+
) -> _VarReturn:
412435
"""
413436
Declare a Boolean variable.
414437
@@ -480,7 +503,7 @@ def String(
480503
constant: bool = False,
481504
# Visibility
482505
protected: bool = False,
483-
) -> Var:
506+
) -> _VarReturn:
484507
"""
485508
Declare a String variable.
486509
@@ -536,7 +559,7 @@ def String(
536559

537560

538561
@beartype
539-
def submodel(model_class: Type, **overrides: Any) -> SubmodelField:
562+
def submodel(model_class: Type, **overrides: NumericValue) -> SubmodelField:
540563
"""
541564
Declare a submodel (nested model) with optional parameter overrides.
542565
@@ -547,7 +570,7 @@ def submodel(model_class: Type, **overrides: Any) -> SubmodelField:
547570
----------
548571
model_class : Type
549572
The model class to instantiate as a submodel
550-
**overrides : Any
573+
**overrides : NumericValue
551574
Parameter value overrides. The parameter names must match
552575
parameters defined in the submodel class.
553576
@@ -575,8 +598,9 @@ def submodel(model_class: Type, **overrides: Any) -> SubmodelField:
575598
# =============================================================================
576599

577600

601+
@dataclass_transform(field_specifiers=(Real, Integer, Boolean, String, var))
578602
@beartype
579-
def model(cls: Type[Any]) -> Type[Any]:
603+
def model(cls: Type[_T]) -> Type[_T]:
580604
"""
581605
Decorator to convert a class into a Cyecca model.
582606
@@ -670,6 +694,8 @@ class ModelClass(ModelInstance):
670694
__name__ = cls.__name__
671695
__qualname__ = cls.__qualname__
672696
__module__ = cls.__module__
697+
# Copy annotations from original class for IDE autocomplete
698+
__annotations__ = getattr(cls, "__annotations__", {})
673699

674700
_equations_methods = equations_methods
675701
_initial_equations_methods = initial_equations_methods
@@ -704,16 +730,17 @@ def get_algorithm(self) -> List[Assignment]:
704730

705731
ModelClass._dsl_metadata = metadata
706732

707-
return ModelClass
733+
return cast(Type[_T], ModelClass)
708734

709735

710736
# =============================================================================
711737
# @function decorator - Modelica functions (Ch. 12)
712738
# =============================================================================
713739

714740

741+
@dataclass_transform(field_specifiers=(Real, Integer, Boolean, String, var))
715742
@beartype
716-
def function(cls: Type[Any]) -> Type[Any]:
743+
def function(cls: Type[_T]) -> Type[_T]:
717744
"""
718745
Decorator to convert a class into a Cyecca function.
719746
@@ -787,8 +814,9 @@ def get_function_metadata(self) -> FunctionMetadata:
787814
# =============================================================================
788815

789816

817+
@dataclass_transform(field_specifiers=(Real, Integer, Boolean, String, var))
790818
@beartype
791-
def block(cls: Type[Any]) -> Type[Any]:
819+
def block(cls: Type[_T]) -> Type[_T]:
792820
"""
793821
Decorator to convert a class into a Cyecca block.
794822
@@ -832,8 +860,9 @@ def block(cls: Type[Any]) -> Type[Any]:
832860
# =============================================================================
833861

834862

863+
@dataclass_transform(field_specifiers=(Real, Integer, Boolean, String, var))
835864
@beartype
836-
def connector(cls: Type[Any]) -> Type[Any]:
865+
def connector(cls: Type[_T]) -> Type[_T]:
837866
"""
838867
Decorator to convert a class into a Cyecca connector.
839868

cyecca/dsl/expr.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,37 @@
3131

3232
from dataclasses import dataclass
3333
from enum import Enum, auto
34-
from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple
34+
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
3535

3636
import numpy as np
3737
from beartype import beartype
3838

3939
from cyecca.dsl.types import Indices, Shape
4040

4141
if TYPE_CHECKING:
42+
from cyecca.dsl.algorithm import AlgorithmVar
4243
from cyecca.dsl.variables import DerivativeExpr, SymbolicVar, TimeVar
4344

45+
# Type alias for anything that can be converted to an Expr
46+
# Used in operators and functions that accept expressions
47+
# Note: At runtime, beartype uses object (accepting anything) and to_expr() validates.
48+
# For static type checking, the Union provides proper type hints.
49+
if TYPE_CHECKING:
50+
ExprLike = Union[
51+
"Expr",
52+
"SymbolicVar",
53+
"DerivativeExpr",
54+
"TimeVar",
55+
"AlgorithmVar",
56+
float,
57+
int,
58+
List["ExprLike"],
59+
np.ndarray,
60+
]
61+
else:
62+
# At runtime, accept any object - to_expr() will validate and convert
63+
ExprLike = object
64+
4465

4566
class ExprKind(Enum):
4667
"""Kinds of expression nodes."""
@@ -330,7 +351,7 @@ def __hash__(self) -> int:
330351

331352

332353
@beartype
333-
def to_expr(x: Any) -> Expr:
354+
def to_expr(x: ExprLike) -> Expr:
334355
"""Convert various types to Expr."""
335356
if isinstance(x, Expr):
336357
return x

0 commit comments

Comments
 (0)