17
17
18
18
from pyomo .core .base .constraint import Constraint
19
19
from pyomo .core .base .expression import Expression
20
+ from pyomo .core .base .objective import Objective
20
21
from pyomo .core .base .external import ExternalFunction
21
22
from pyomo .core .expr .visitor import StreamBasedExpressionVisitor
22
23
from pyomo .core .expr .numeric_expr import ExternalFunctionExpression
23
- from pyomo .core .expr .numvalue import native_types
24
+ from pyomo .core .expr .numvalue import native_types , NumericValue
24
25
25
26
26
27
class _ExternalFunctionVisitor (StreamBasedExpressionVisitor ):
28
+ def __init__ (self , descend_into_named_expressions = True ):
29
+ super ().__init__ ()
30
+ self ._descend_into_named_expressions = descend_into_named_expressions
31
+ self .named_expressions = []
32
+
27
33
def initializeWalker (self , expr ):
28
34
self ._functions = []
29
35
self ._seen = set ()
30
36
return True , None
31
37
38
+ def beforeChild (self , parent , child , index ):
39
+ if child .__class__ in native_types :
40
+ return False , None
41
+ elif (
42
+ not self ._descend_into_named_expressions
43
+ and child .is_named_expression_type ()
44
+ ):
45
+ self .named_expressions .append (child )
46
+ return False , None
47
+ return True , None
48
+
32
49
def exitNode (self , node , data ):
33
50
if type (node ) is ExternalFunctionExpression :
34
51
if id (node ) not in self ._seen :
@@ -38,26 +55,35 @@ def exitNode(self, node, data):
38
55
def finalizeResult (self , result ):
39
56
return self ._functions
40
57
41
- def enterNode (self , node ):
42
- pass
43
-
44
- def acceptChildResult (self , node , data , child_result , child_idx ):
45
- pass
46
-
47
- def acceptChildResult (self , node , data , child_result , child_idx ):
48
- if child_result .__class__ in native_types :
49
- return False , None
50
- return child_result .is_expression_type (), None
51
-
52
58
53
59
def identify_external_functions (expr ):
54
60
yield from _ExternalFunctionVisitor ().walk_expression (expr )
55
61
56
62
57
63
def add_local_external_functions (block ):
58
64
ef_exprs = []
59
- for comp in block .component_data_objects ((Constraint , Expression ), active = True ):
60
- ef_exprs .extend (identify_external_functions (comp .expr ))
65
+ named_expressions = []
66
+ visitor = _ExternalFunctionVisitor (descend_into_named_expressions = False )
67
+ for comp in block .component_data_objects (
68
+ (Constraint , Expression , Objective ), active = True
69
+ ):
70
+ ef_exprs .extend (visitor .walk_expression (comp .expr ))
71
+ named_expr_set = ComponentSet (visitor .named_expressions )
72
+ # List of unique named expressions
73
+ named_expressions = list (named_expr_set )
74
+ while named_expressions :
75
+ expr = named_expressions .pop ()
76
+ # Clear named expression cache so we don't re-check named expressions
77
+ # we've seen before.
78
+ visitor .named_expressions .clear ()
79
+ ef_exprs .extend (visitor .walk_expression (expr ))
80
+ # Only add to the stack named expressions that we have
81
+ # not encountered yet.
82
+ for local_expr in visitor .named_expressions :
83
+ if local_expr not in named_expr_set :
84
+ named_expressions .append (local_expr )
85
+ named_expr_set .add (local_expr )
86
+
61
87
unique_functions = []
62
88
fcn_set = set ()
63
89
for expr in ef_exprs :
@@ -148,7 +174,14 @@ class TemporarySubsystemManager(object):
148
174
149
175
"""
150
176
151
- def __init__ (self , to_fix = None , to_deactivate = None , to_reset = None , to_unfix = None ):
177
+ def __init__ (
178
+ self ,
179
+ to_fix = None ,
180
+ to_deactivate = None ,
181
+ to_reset = None ,
182
+ to_unfix = None ,
183
+ remove_bounds_on_fix = False ,
184
+ ):
152
185
"""
153
186
Arguments
154
187
---------
@@ -168,6 +201,8 @@ def __init__(self, to_fix=None, to_deactivate=None, to_reset=None, to_unfix=None
168
201
List of var data objects to be temporarily unfixed. These are
169
202
restored to their original status on exit from this object's
170
203
context manager.
204
+ remove_bounds_on_fix: Bool
205
+ Whether bounds should be removed temporarily for fixed variables
171
206
172
207
"""
173
208
if to_fix is None :
@@ -194,6 +229,8 @@ def __init__(self, to_fix=None, to_deactivate=None, to_reset=None, to_unfix=None
194
229
self ._con_was_active = None
195
230
self ._comp_original_value = None
196
231
self ._var_was_unfixed = None
232
+ self ._remove_bounds_on_fix = remove_bounds_on_fix
233
+ self ._fixed_var_bounds = None
197
234
198
235
def __enter__ (self ):
199
236
to_fix = self ._vars_to_fix
@@ -203,8 +240,13 @@ def __enter__(self):
203
240
self ._var_was_fixed = [(var , var .fixed ) for var in to_fix + to_unfix ]
204
241
self ._con_was_active = [(con , con .active ) for con in to_deactivate ]
205
242
self ._comp_original_value = [(comp , comp .value ) for comp in to_set ]
243
+ self ._fixed_var_bounds = [(var .lb , var .ub ) for var in to_fix ]
206
244
207
245
for var in self ._vars_to_fix :
246
+ if self ._remove_bounds_on_fix :
247
+ # TODO: Potentially override var.domain as well?
248
+ var .setlb (None )
249
+ var .setub (None )
208
250
var .fix ()
209
251
210
252
for con in self ._cons_to_deactivate :
@@ -223,6 +265,11 @@ def __exit__(self, ex_type, ex_val, ex_bt):
223
265
var .fix ()
224
266
else :
225
267
var .unfix ()
268
+ if self ._remove_bounds_on_fix :
269
+ for var , (lb , ub ) in zip (self ._vars_to_fix , self ._fixed_var_bounds ):
270
+ var .setlb (lb )
271
+ var .setub (ub )
272
+
226
273
for con , was_active in self ._con_was_active :
227
274
if was_active :
228
275
con .activate ()
0 commit comments