Skip to content

Commit 0d3400f

Browse files
committed
add type hints to components
1 parent 2954417 commit 0d3400f

File tree

6 files changed

+84
-7
lines changed

6 files changed

+84
-7
lines changed

Diff for: pyomo/core/base/block.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import copy
1314
import logging
1415
import sys
@@ -21,6 +22,7 @@
2122
from io import StringIO
2223
from itertools import filterfalse, chain
2324
from operator import itemgetter, attrgetter
25+
from typing import Union, Any, Type
2426

2527
from pyomo.common.autoslots import AutoSlots
2628
from pyomo.common.collections import Mapping
@@ -44,6 +46,7 @@
4446
from pyomo.core.base.indexed_component import (
4547
ActiveIndexedComponent,
4648
UnindexedComponent_set,
49+
IndexedComponent,
4750
)
4851

4952
from pyomo.opt.base import ProblemFormat, guess_format
@@ -539,7 +542,7 @@ def __init__(self, component):
539542
super(_BlockData, self).__setattr__('_decl_order', [])
540543
self._private_data = None
541544

542-
def __getattr__(self, val):
545+
def __getattr__(self, val) -> Union[Component, IndexedComponent, Any]:
543546
if val in ModelComponentFactory:
544547
return _component_decorator(self, ModelComponentFactory.get_class(val))
545548
# Since the base classes don't support getattr, we can just
@@ -548,7 +551,7 @@ def __getattr__(self, val):
548551
"'%s' object has no attribute '%s'" % (self.__class__.__name__, val)
549552
)
550553

551-
def __setattr__(self, name, val):
554+
def __setattr__(self, name: str, val: Union[Component, IndexedComponent, Any]):
552555
"""
553556
Set an attribute of a block data object.
554557
"""
@@ -2007,6 +2010,18 @@ class Block(ActiveIndexedComponent):
20072010
_ComponentDataClass = _BlockData
20082011
_private_data_initializers = defaultdict(lambda: dict)
20092012

2013+
@overload
2014+
def __new__(cls: Type[Block], *args, **kwds) -> Union[ScalarBlock, IndexedBlock]:
2015+
...
2016+
2017+
@overload
2018+
def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock:
2019+
...
2020+
2021+
@overload
2022+
def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock:
2023+
...
2024+
20102025
def __new__(cls, *args, **kwds):
20112026
if cls != Block:
20122027
return super(Block, cls).__new__(cls)
@@ -2251,6 +2266,9 @@ class IndexedBlock(Block):
22512266
def __init__(self, *args, **kwds):
22522267
Block.__init__(self, *args, **kwds)
22532268

2269+
def __getitem__(self, index) -> _BlockData:
2270+
return super().__getitem__(index)
2271+
22542272

22552273
#
22562274
# Deprecated functions.

Diff for: pyomo/core/base/constraint.py

+17
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import sys
1314
import logging
1415
from weakref import ref as weakref_ref
1516
from pyomo.common.pyomo_typing import overload
17+
from typing import Union, Type
1618

1719
from pyomo.common.deprecation import RenamedClass
1820
from pyomo.common.errors import DeveloperError
@@ -728,6 +730,18 @@ class Infeasible(object):
728730
Violated = Infeasible
729731
Satisfied = Feasible
730732

733+
@overload
734+
def __new__(cls: Type[Constraint], *args, **kwds) -> Union[ScalarConstraint, IndexedConstraint]:
735+
...
736+
737+
@overload
738+
def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint:
739+
...
740+
741+
@overload
742+
def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint:
743+
...
744+
731745
def __new__(cls, *args, **kwds):
732746
if cls != Constraint:
733747
return super(Constraint, cls).__new__(cls)
@@ -1019,6 +1033,9 @@ class IndexedConstraint(Constraint):
10191033
def add(self, index, expr):
10201034
"""Add a constraint with a given index."""
10211035
return self.__setitem__(index, expr)
1036+
1037+
def __getitem__(self, index) -> _GeneralConstraintData:
1038+
return super().__getitem__(index)
10221039

10231040

10241041
@ModelComponentFactory.register("A list of constraint expressions.")

Diff for: pyomo/core/base/indexed_component.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pyomo.core.base as BASE
1919
from pyomo.core.base.indexed_component_slice import IndexedComponent_slice
2020
from pyomo.core.base.initializer import Initializer
21-
from pyomo.core.base.component import Component, ActiveComponent
21+
from pyomo.core.base.component import Component, ActiveComponent, ComponentData
2222
from pyomo.core.base.config import PyomoOptions
2323
from pyomo.core.base.enums import SortComponents
2424
from pyomo.core.base.global_set import UnindexedComponent_set
@@ -606,7 +606,7 @@ def iteritems(self):
606606
"""Return a list (index,data) tuples from the dictionary"""
607607
return self.items()
608608

609-
def __getitem__(self, index):
609+
def __getitem__(self, index) -> ComponentData:
610610
"""
611611
This method returns the data corresponding to the given index.
612612
"""

Diff for: pyomo/core/base/param.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import sys
1314
import types
1415
import logging
1516
from weakref import ref as weakref_ref
1617
from pyomo.common.pyomo_typing import overload
18+
from typing import Union, Type
1719

1820
from pyomo.common.autoslots import AutoSlots
1921
from pyomo.common.deprecation import deprecation_warning, RenamedClass
@@ -291,6 +293,18 @@ class NoValue(object):
291293

292294
pass
293295

296+
@overload
297+
def __new__(cls: Type[Param], *args, **kwds) -> Union[ScalarParam, IndexedParam]:
298+
...
299+
300+
@overload
301+
def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam:
302+
...
303+
304+
@overload
305+
def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam:
306+
...
307+
294308
def __new__(cls, *args, **kwds):
295309
if cls != Param:
296310
return super(Param, cls).__new__(cls)
@@ -983,7 +997,7 @@ def _create_objects_for_deepcopy(self, memo, component_list):
983997
# between potentially variable GetItemExpression objects and
984998
# "constant" GetItemExpression objects. That will need to wait for
985999
# the expression rework [JDS; Nov 22].
986-
def __getitem__(self, args):
1000+
def __getitem__(self, args) -> _ParamData:
9871001
try:
9881002
return super().__getitem__(args)
9891003
except:

Diff for: pyomo/core/base/set.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import inspect
1314
import itertools
1415
import logging
1516
import math
1617
import sys
1718
import weakref
1819
from pyomo.common.pyomo_typing import overload
20+
from typing import Union, Type, Any
21+
from collections.abc import Iterator
1922

2023
from pyomo.common.collections import ComponentSet
2124
from pyomo.common.deprecation import deprecated, deprecation_warning, RenamedClass
@@ -569,7 +572,7 @@ def isordered(self):
569572
def subsets(self, expand_all_set_operators=None):
570573
return iter((self,))
571574

572-
def __iter__(self):
575+
def __iter__(self) -> Iterator[Any]:
573576
"""Iterate over the set members
574577
575578
Raises AttributeError for non-finite sets. This must be
@@ -1967,6 +1970,14 @@ class SortedOrder(object):
19671970
_ValidOrderedAuguments = {True, False, InsertionOrder, SortedOrder}
19681971
_UnorderedInitializers = {set}
19691972

1973+
@overload
1974+
def __new__(cls: Type[Set], *args, **kwds) -> Union[_SetData, IndexedSet]:
1975+
...
1976+
1977+
@overload
1978+
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet:
1979+
...
1980+
19701981
def __new__(cls, *args, **kwds):
19711982
if cls is not Set:
19721983
return super(Set, cls).__new__(cls)
@@ -2373,6 +2384,9 @@ def data(self):
23732384
"Return a dict containing the data() of each Set in this IndexedSet"
23742385
return {k: v.data() for k, v in self.items()}
23752386

2387+
def __getitem__(self, index) -> _SetData:
2388+
return super().__getitem__(index)
2389+
23762390

23772391
class FiniteScalarSet(_FiniteSetData, Set):
23782392
def __init__(self, **kwds):

Diff for: pyomo/core/base/var.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import logging
1314
import sys
1415
from pyomo.common.pyomo_typing import overload
1516
from weakref import ref as weakref_ref
17+
from typing import Union, Type
1618

1719
from pyomo.common.deprecation import RenamedClass
1820
from pyomo.common.log import is_debug_set
@@ -668,6 +670,18 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin):
668670

669671
_ComponentDataClass = _GeneralVarData
670672

673+
@overload
674+
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]:
675+
...
676+
677+
@overload
678+
def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar:
679+
...
680+
681+
@overload
682+
def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar:
683+
...
684+
671685
def __new__(cls, *args, **kwargs):
672686
if cls is not Var:
673687
return super(Var, cls).__new__(cls)
@@ -1046,7 +1060,7 @@ def domain(self, domain):
10461060
# between potentially variable GetItemExpression objects and
10471061
# "constant" GetItemExpression objects. That will need to wait for
10481062
# the expression rework [JDS; Nov 22].
1049-
def __getitem__(self, args):
1063+
def __getitem__(self, args) -> _GeneralVarData:
10501064
try:
10511065
return super().__getitem__(args)
10521066
except RuntimeError:

0 commit comments

Comments
 (0)