Skip to content

Commit 4a5bf8a

Browse files
authored
Merge pull request #3241 from emma58/fix-domain-bug-in-var-aggregator
contrib.preprocessing: Fixing bug where variable aggregator did not intersect domains
2 parents bd640f8 + 127f8c6 commit 4a5bf8a

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

pyomo/contrib/preprocessing/plugins/var_aggregator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@
1313

1414

1515
from pyomo.common.collections import ComponentMap, ComponentSet
16-
from pyomo.core.base import Block, Constraint, VarList, Objective, TransformationFactory
16+
from pyomo.core.base import (
17+
Block,
18+
Constraint,
19+
VarList,
20+
Objective,
21+
Reals,
22+
TransformationFactory,
23+
)
1724
from pyomo.core.expr import ExpressionReplacementVisitor
1825
from pyomo.core.expr.numvalue import value
1926
from pyomo.core.plugins.transform.hierarchy import IsomorphicTransformation
@@ -248,6 +255,12 @@ def _apply_to(self, model, detect_fixed_vars=True):
248255
# the variables in its equality set.
249256
z_agg.setlb(max_if_not_None(v.lb for v in eq_set if v.has_lb()))
250257
z_agg.setub(min_if_not_None(v.ub for v in eq_set if v.has_ub()))
258+
# Set the domain of the aggregate variable to the intersection of
259+
# the domains of the variables in its equality set
260+
domain = Reals
261+
for v in eq_set:
262+
domain = domain & v.domain
263+
z_agg.domain = domain
251264

252265
# Set the fixed status of the aggregate var
253266
fixed_vars = [v for v in eq_set if v.fixed]

pyomo/contrib/preprocessing/tests/test_var_aggregator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
max_if_not_None,
2020
min_if_not_None,
2121
)
22+
from pyomo.core.expr.compare import assertExpressionsEqual
2223
from pyomo.environ import (
24+
Binary,
2325
ConcreteModel,
2426
Constraint,
2527
ConstraintList,
28+
maximize,
2629
Objective,
2730
RangeSet,
31+
Reals,
2832
SolverFactory,
2933
TransformationFactory,
3034
Var,
@@ -210,6 +214,36 @@ def test_var_update(self):
210214
self.assertEqual(m.x.value, 0)
211215
self.assertEqual(m.y.value, 0)
212216

217+
def test_binary_inequality(self):
218+
m = ConcreteModel()
219+
m.x = Var(domain=Binary)
220+
m.y = Var(domain=Binary)
221+
m.c = Constraint(expr=m.x == m.y)
222+
m.o = Objective(expr=0.5 * m.x + m.y, sense=maximize)
223+
TransformationFactory('contrib.aggregate_vars').apply_to(m)
224+
var_to_z = m._var_aggregator_info.var_to_z
225+
z = var_to_z[m.x]
226+
self.assertIs(var_to_z[m.y], z)
227+
self.assertEqual(z.domain, Binary)
228+
self.assertEqual(z.lb, 0)
229+
self.assertEqual(z.ub, 1)
230+
assertExpressionsEqual(self, m.o.expr, 0.5 * z + z)
231+
232+
def test_equality_different_domains(self):
233+
m = ConcreteModel()
234+
m.x = Var(domain=Reals, bounds=(1, 2))
235+
m.y = Var(domain=Binary)
236+
m.c = Constraint(expr=m.x == m.y)
237+
m.o = Objective(expr=0.5 * m.x + m.y, sense=maximize)
238+
TransformationFactory('contrib.aggregate_vars').apply_to(m)
239+
var_to_z = m._var_aggregator_info.var_to_z
240+
z = var_to_z[m.x]
241+
self.assertIs(var_to_z[m.y], z)
242+
self.assertEqual(z.lb, 1)
243+
self.assertEqual(z.ub, 1)
244+
self.assertEqual(z.domain, Binary)
245+
assertExpressionsEqual(self, m.o.expr, 0.5 * z + z)
246+
213247

214248
if __name__ == '__main__':
215249
unittest.main()

0 commit comments

Comments
 (0)