Skip to content

Commit dc72d1d

Browse files
authored
Merge pull request Pyomo#3368 from jsiirola/set-filter-fix
Resolve issue in filter/validate deprecation path
2 parents 459f8e8 + 0fff1e7 commit dc72d1d

File tree

2 files changed

+60
-22
lines changed

2 files changed

+60
-22
lines changed

pyomo/core/base/set.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -1484,18 +1484,7 @@ def _cb_validate_filter(self, mode, val_iter):
14841484
try:
14851485
flag = fcn(block, (), *vstar)
14861486
if flag:
1487-
deprecation_warning(
1488-
f"{self.__class__.__name__} {self.name}: '{mode}=' "
1489-
"callback signature matched (block, *value). "
1490-
"Please update the callback to match the signature "
1491-
f"(block, value{', *index' if comp.is_indexed() else ''}).",
1492-
version='6.8.0',
1493-
)
1494-
orig_fcn = fcn._fcn
1495-
fcn = ParameterizedScalarCallInitializer(
1496-
lambda m, v: orig_fcn(m, *v), True
1497-
)
1498-
setattr(comp, '_' + mode, fcn)
1487+
self._filter_validate_scalar_api_deprecation(mode, warning=True)
14991488
yield value
15001489
continue
15011490
except TypeError:
@@ -1536,6 +1525,21 @@ def _cb_validate_filter(self, mode, val_iter):
15361525
)
15371526
raise exc from None
15381527

1528+
def _filter_validate_scalar_api_deprecation(self, mode, warning):
1529+
comp = self.parent_component()
1530+
fcn = getattr(comp, '_' + mode)
1531+
if warning:
1532+
deprecation_warning(
1533+
f"{self.__class__.__name__} {self.name}: '{mode}=' "
1534+
"callback signature matched (block, *value). "
1535+
"Please update the callback to match the signature "
1536+
f"(block, value{', *index' if comp.is_indexed() else ''}).",
1537+
version='6.8.0',
1538+
)
1539+
orig_fcn = fcn._fcn
1540+
fcn = ParameterizedScalarCallInitializer(lambda m, v: orig_fcn(m, *v), True)
1541+
setattr(comp, '_' + mode, fcn)
1542+
15391543
def _cb_normalized_dimen_verifier(self, dimen, val_iter):
15401544
for value in val_iter:
15411545
if value.__class__ in native_types:
@@ -2256,14 +2260,20 @@ def __init__(self, *args, **kwds):
22562260
self._init_values._init = CountedCallInitializer(
22572261
self, self._init_values._init
22582262
)
2259-
# HACK: the DAT parser needs to know the domain of a set in
2260-
# order to correctly parse the data stream.
2263+
22612264
if not self.is_indexed():
2265+
# HACK: the DAT parser needs to know the domain of a set in
2266+
# order to correctly parse the data stream.
22622267
if self._init_domain.constant():
22632268
self._domain = self._init_domain(self.parent_block(), None, self)
22642269
if self._init_dimen.constant():
22652270
self._dimen = self._init_dimen(self.parent_block(), None)
22662271

2272+
if self._filter.__class__ is ParameterizedIndexedCallInitializer:
2273+
self._filter_validate_scalar_api_deprecation('filter', warning=False)
2274+
if self._validate.__class__ is ParameterizedIndexedCallInitializer:
2275+
self._filter_validate_scalar_api_deprecation('validate', warning=False)
2276+
22672277
@deprecated(
22682278
"check_values() is deprecated: Sets only contain valid members", version='5.7'
22692279
)

pyomo/core/tests/unit/test_set.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -4181,6 +4181,19 @@ def test_indexed_set(self):
41814181
self.assertIs(type(m.I[3]), InsertionOrderSetData)
41824182
self.assertEqual(m.I.data(), {1: (4, 2, 5), 2: (4, 2, 5), 3: (4, 2, 5)})
41834183

4184+
# Explicit (constant dict) construction
4185+
m = ConcreteModel()
4186+
m.I = Set([1, 2], initialize={1: (4, 2, 5), 2: (7, 6)})
4187+
self.assertEqual(len(m.I), 2)
4188+
self.assertEqual(list(m.I[1]), [4, 2, 5])
4189+
self.assertEqual(list(m.I[2]), [7, 6])
4190+
self.assertIsNot(m.I[1], m.I[2])
4191+
self.assertTrue(m.I[1].isordered())
4192+
self.assertTrue(m.I[2].isordered())
4193+
self.assertIs(type(m.I[1]), InsertionOrderSetData)
4194+
self.assertIs(type(m.I[2]), InsertionOrderSetData)
4195+
self.assertEqual(m.I.data(), {1: (4, 2, 5), 2: (7, 6)})
4196+
41844197
# Explicit (constant) construction
41854198
m = ConcreteModel()
41864199
m.I = Set([1, 2, 3], initialize=(4, 2, 5), ordered=Set.SortedOrder)
@@ -4255,7 +4268,7 @@ def test_indexing(self):
42554268
def test_add_filter_validate(self):
42564269
m = ConcreteModel()
42574270
m.I = Set(domain=Integers)
4258-
self.assertIs(m.I.filter, None)
4271+
self.assertIs(m.I._filter, None)
42594272
with self.assertRaisesRegex(
42604273
ValueError,
42614274
r"Cannot add value 1.5 to Set I.\n"
@@ -4302,7 +4315,7 @@ def _l_tri(model, i, j):
43024315
return i >= j
43034316

43044317
m.K = Set(initialize=RangeSet(3) * RangeSet(3), filter=_l_tri)
4305-
self.assertIsInstance(m.K.filter, ParameterizedScalarCallInitializer)
4318+
self.assertIsInstance(m.K._filter, ParameterizedScalarCallInitializer)
43064319
self.assertEqual(list(m.K), [(1, 1), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3)])
43074320

43084321
output = StringIO()
@@ -4334,6 +4347,18 @@ def _lt_3(model, i):
43344347
self.assertEqual(output.getvalue(), "")
43354348
self.assertEqual(list(m.L[2]), [1, 2, 0])
43364349

4350+
# This tests that the deprecation path works correctly in the
4351+
# case that the callback doesn't raise an error or ever return
4352+
# False
4353+
4354+
def _l_off_diag(model, i, j):
4355+
self.assertIs(model, m)
4356+
return i != j
4357+
4358+
m.M = Set(initialize=RangeSet(3) * RangeSet(3), filter=_l_off_diag)
4359+
self.assertIsInstance(m.M._filter, ParameterizedScalarCallInitializer)
4360+
self.assertEqual(list(m.M), [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)])
4361+
43374362
m = ConcreteModel()
43384363

43394364
def _validate(model, val):
@@ -4374,12 +4399,15 @@ def _validate(model, i, j):
43744399
m.I2 = Set(validate=_validate)
43754400
with LoggingIntercept(module='pyomo.core') as output:
43764401
self.assertTrue(m.I2.add((0, 1)))
4377-
self.assertRegex(
4378-
output.getvalue().replace('\n', ' '),
4379-
r"DEPRECATED: OrderedScalarSet I2: 'validate=' callback "
4380-
r"signature matched \(block, \*value\). Please update the "
4381-
r"callback to match the signature \(block, value\)",
4382-
)
4402+
# Note that we are not emitting a deprecation warning (yet)
4403+
# for scalar sets
4404+
# self.assertEqual(output.getvalue(), "")
4405+
# output.getvalue().replace('\n', ' '),
4406+
# r"DEPRECATED: OrderedScalarSet I2: 'validate=' callback "
4407+
# r"signature matched \(block, \*value\). Please update the "
4408+
# r"callback to match the signature \(block, value\)",
4409+
# )
4410+
self.assertEqual(output.getvalue(), "")
43834411
with LoggingIntercept(module='pyomo.core') as output:
43844412
with self.assertRaisesRegex(
43854413
ValueError,

0 commit comments

Comments
 (0)