Skip to content

Commit b06ddea

Browse files
authored
Merge pull request #3173 from michaelbynum/typing
add type hints to components
2 parents f8b0239 + c314992 commit b06ddea

File tree

6 files changed

+84
-8
lines changed

6 files changed

+84
-8
lines changed

Diff for: pyomo/core/base/block.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import copy
1314
import logging
1415
import sys
@@ -21,6 +22,7 @@
2122
from io import StringIO
2223
from itertools import filterfalse, chain
2324
from operator import itemgetter, attrgetter
25+
from typing import Union, Any, Type
2426

2527
from pyomo.common.autoslots import AutoSlots
2628
from pyomo.common.collections import Mapping
@@ -44,6 +46,7 @@
4446
from pyomo.core.base.indexed_component import (
4547
ActiveIndexedComponent,
4648
UnindexedComponent_set,
49+
IndexedComponent,
4750
)
4851

4952
from pyomo.opt.base import ProblemFormat, guess_format
@@ -539,7 +542,7 @@ def __init__(self, component):
539542
super(_BlockData, self).__setattr__('_decl_order', [])
540543
self._private_data = None
541544

542-
def __getattr__(self, val):
545+
def __getattr__(self, val) -> Union[Component, IndexedComponent, Any]:
543546
if val in ModelComponentFactory:
544547
return _component_decorator(self, ModelComponentFactory.get_class(val))
545548
# Since the base classes don't support getattr, we can just
@@ -548,7 +551,7 @@ def __getattr__(self, val):
548551
"'%s' object has no attribute '%s'" % (self.__class__.__name__, val)
549552
)
550553

551-
def __setattr__(self, name, val):
554+
def __setattr__(self, name: str, val: Union[Component, IndexedComponent, Any]):
552555
"""
553556
Set an attribute of a block data object.
554557
"""
@@ -2007,6 +2010,17 @@ class Block(ActiveIndexedComponent):
20072010
_ComponentDataClass = _BlockData
20082011
_private_data_initializers = defaultdict(lambda: dict)
20092012

2013+
@overload
2014+
def __new__(
2015+
cls: Type[Block], *args, **kwds
2016+
) -> Union[ScalarBlock, IndexedBlock]: ...
2017+
2018+
@overload
2019+
def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock: ...
2020+
2021+
@overload
2022+
def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock: ...
2023+
20102024
def __new__(cls, *args, **kwds):
20112025
if cls != Block:
20122026
return super(Block, cls).__new__(cls)
@@ -2251,6 +2265,11 @@ class IndexedBlock(Block):
22512265
def __init__(self, *args, **kwds):
22522266
Block.__init__(self, *args, **kwds)
22532267

2268+
@overload
2269+
def __getitem__(self, index) -> _BlockData: ...
2270+
2271+
__getitem__ = IndexedComponent.__getitem__ # type: ignore
2272+
22542273

22552274
#
22562275
# Deprecated functions.

Diff for: pyomo/core/base/constraint.py

+19
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import sys
1314
import logging
1415
from weakref import ref as weakref_ref
1516
from pyomo.common.pyomo_typing import overload
17+
from typing import Union, Type
1618

1719
from pyomo.common.deprecation import RenamedClass
1820
from pyomo.common.errors import DeveloperError
@@ -42,6 +44,7 @@
4244
ActiveIndexedComponent,
4345
UnindexedComponent_set,
4446
rule_wrapper,
47+
IndexedComponent,
4548
)
4649
from pyomo.core.base.set import Set
4750
from pyomo.core.base.disable_methods import disable_methods
@@ -728,6 +731,17 @@ class Infeasible(object):
728731
Violated = Infeasible
729732
Satisfied = Feasible
730733

734+
@overload
735+
def __new__(
736+
cls: Type[Constraint], *args, **kwds
737+
) -> Union[ScalarConstraint, IndexedConstraint]: ...
738+
739+
@overload
740+
def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint: ...
741+
742+
@overload
743+
def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint: ...
744+
731745
def __new__(cls, *args, **kwds):
732746
if cls != Constraint:
733747
return super(Constraint, cls).__new__(cls)
@@ -1020,6 +1034,11 @@ def add(self, index, expr):
10201034
"""Add a constraint with a given index."""
10211035
return self.__setitem__(index, expr)
10221036

1037+
@overload
1038+
def __getitem__(self, index) -> _GeneralConstraintData: ...
1039+
1040+
__getitem__ = IndexedComponent.__getitem__ # type: ignore
1041+
10231042

10241043
@ModelComponentFactory.register("A list of constraint expressions.")
10251044
class ConstraintList(IndexedConstraint):

Diff for: pyomo/core/base/indexed_component.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pyomo.core.base as BASE
1919
from pyomo.core.base.indexed_component_slice import IndexedComponent_slice
2020
from pyomo.core.base.initializer import Initializer
21-
from pyomo.core.base.component import Component, ActiveComponent
21+
from pyomo.core.base.component import Component, ActiveComponent, ComponentData
2222
from pyomo.core.base.config import PyomoOptions
2323
from pyomo.core.base.enums import SortComponents
2424
from pyomo.core.base.global_set import UnindexedComponent_set
@@ -606,7 +606,7 @@ def iteritems(self):
606606
"""Return a list (index,data) tuples from the dictionary"""
607607
return self.items()
608608

609-
def __getitem__(self, index):
609+
def __getitem__(self, index) -> ComponentData:
610610
"""
611611
This method returns the data corresponding to the given index.
612612
"""

Diff for: pyomo/core/base/param.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import sys
1314
import types
1415
import logging
1516
from weakref import ref as weakref_ref
1617
from pyomo.common.pyomo_typing import overload
18+
from typing import Union, Type
1719

1820
from pyomo.common.autoslots import AutoSlots
1921
from pyomo.common.deprecation import deprecation_warning, RenamedClass
@@ -291,6 +293,17 @@ class NoValue(object):
291293

292294
pass
293295

296+
@overload
297+
def __new__(
298+
cls: Type[Param], *args, **kwds
299+
) -> Union[ScalarParam, IndexedParam]: ...
300+
301+
@overload
302+
def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam: ...
303+
304+
@overload
305+
def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam: ...
306+
294307
def __new__(cls, *args, **kwds):
295308
if cls != Param:
296309
return super(Param, cls).__new__(cls)
@@ -983,7 +996,7 @@ def _create_objects_for_deepcopy(self, memo, component_list):
983996
# between potentially variable GetItemExpression objects and
984997
# "constant" GetItemExpression objects. That will need to wait for
985998
# the expression rework [JDS; Nov 22].
986-
def __getitem__(self, args):
999+
def __getitem__(self, args) -> _ParamData:
9871000
try:
9881001
return super().__getitem__(args)
9891002
except:

Diff for: pyomo/core/base/set.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import inspect
1314
import itertools
1415
import logging
1516
import math
1617
import sys
1718
import weakref
1819
from pyomo.common.pyomo_typing import overload
20+
from typing import Union, Type, Any as typingAny
21+
from collections.abc import Iterator
1922

2023
from pyomo.common.collections import ComponentSet
2124
from pyomo.common.deprecation import deprecated, deprecation_warning, RenamedClass
@@ -569,7 +572,7 @@ def isordered(self):
569572
def subsets(self, expand_all_set_operators=None):
570573
return iter((self,))
571574

572-
def __iter__(self):
575+
def __iter__(self) -> Iterator[typingAny]:
573576
"""Iterate over the set members
574577
575578
Raises AttributeError for non-finite sets. This must be
@@ -1967,6 +1970,12 @@ class SortedOrder(object):
19671970
_ValidOrderedAuguments = {True, False, InsertionOrder, SortedOrder}
19681971
_UnorderedInitializers = {set}
19691972

1973+
@overload
1974+
def __new__(cls: Type[Set], *args, **kwds) -> Union[_SetData, IndexedSet]: ...
1975+
1976+
@overload
1977+
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ...
1978+
19701979
def __new__(cls, *args, **kwds):
19711980
if cls is not Set:
19721981
return super(Set, cls).__new__(cls)
@@ -2373,6 +2382,11 @@ def data(self):
23732382
"Return a dict containing the data() of each Set in this IndexedSet"
23742383
return {k: v.data() for k, v in self.items()}
23752384

2385+
@overload
2386+
def __getitem__(self, index) -> _SetData: ...
2387+
2388+
__getitem__ = IndexedComponent.__getitem__ # type: ignore
2389+
23762390

23772391
class FiniteScalarSet(_FiniteSetData, Set):
23782392
def __init__(self, **kwds):

Diff for: pyomo/core/base/var.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12+
from __future__ import annotations
1213
import logging
1314
import sys
1415
from pyomo.common.pyomo_typing import overload
1516
from weakref import ref as weakref_ref
17+
from typing import Union, Type
1618

1719
from pyomo.common.deprecation import RenamedClass
1820
from pyomo.common.log import is_debug_set
@@ -668,6 +670,15 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin):
668670

669671
_ComponentDataClass = _GeneralVarData
670672

673+
@overload
674+
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ...
675+
676+
@overload
677+
def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar: ...
678+
679+
@overload
680+
def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar: ...
681+
671682
def __new__(cls, *args, **kwargs):
672683
if cls is not Var:
673684
return super(Var, cls).__new__(cls)
@@ -688,7 +699,7 @@ def __init__(
688699
dense=True,
689700
units=None,
690701
name=None,
691-
doc=None
702+
doc=None,
692703
): ...
693704

694705
def __init__(self, *args, **kwargs):
@@ -1046,7 +1057,7 @@ def domain(self, domain):
10461057
# between potentially variable GetItemExpression objects and
10471058
# "constant" GetItemExpression objects. That will need to wait for
10481059
# the expression rework [JDS; Nov 22].
1049-
def __getitem__(self, args):
1060+
def __getitem__(self, args) -> _GeneralVarData:
10501061
try:
10511062
return super().__getitem__(args)
10521063
except RuntimeError:

0 commit comments

Comments
 (0)