Skip to content

Commit a9713fe

Browse files
authored
Merge pull request #2840 from jsiirola/lpv2
Follow-up for the LPv2 writer
2 parents e47d6f3 + b94e68f commit a9713fe

File tree

6 files changed

+345
-91
lines changed

6 files changed

+345
-91
lines changed

pyomo/repn/quadratic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131
from pyomo.core.base.expression import ScalarExpression
3232
from . import linear
33-
from .linear import _merge_dict
33+
from .linear import _merge_dict, to_expression
3434

3535
_CONSTANT = linear.ExprType.CONSTANT
3636
_LINEAR = linear.ExprType.LINEAR
@@ -211,8 +211,8 @@ def _handle_product_linear_linear(visitor, node, arg1, arg2):
211211
def _handle_product_nonlinear(visitor, node, arg1, arg2):
212212
ans = visitor.Result()
213213
if not visitor.expand_nonlinear_products:
214-
ans.nonlinear = arg1.to_expression(visitor) * arg2.to_expression(visitor)
215-
return ans
214+
ans.nonlinear = to_expression(visitor, arg1) * to_expression(visitor, arg2)
215+
return _GENERAL, ans
216216

217217
# We are multiplying (A + Bx + Cx^2 + D(x)) * (A + Bx + Cx^2 + Dx))
218218
_, x1 = arg1

pyomo/repn/tests/diffutils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# ___________________________________________________________________________
2+
#
3+
# Pyomo: Python Optimization Modeling Objects
4+
# Copyright (c) 2008-2022
5+
# National Technology and Engineering Solutions of Sandia, LLC
6+
# Under the terms of Contract DE-NA0003525 with National Technology and
7+
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
8+
# rights in this software.
9+
# This software is distributed under the 3-clause BSD License.
10+
# ___________________________________________________________________________
11+
12+
import os
13+
14+
import pyomo.core.expr.current as EXPR
15+
16+
17+
def compare_floats(base, test, abstol=1e-14, reltol=1e-14):
18+
base = base.split()
19+
test = test.split()
20+
if len(base) != len(test):
21+
return False
22+
for i, b in enumerate(base):
23+
if b.strip() == test[i].strip():
24+
continue
25+
try:
26+
b = float(b)
27+
t = float(test[i])
28+
except:
29+
return False
30+
if abs(b - t) < abstol:
31+
continue
32+
if abs((b - t) / max(abs(b), abs(t))) < reltol:
33+
continue
34+
return False
35+
return True
36+
37+
38+
def load_baseline(baseline, testfile, extension, version):
39+
with open(testfile, 'r') as FILE:
40+
test = FILE.read()
41+
if baseline.endswith(f'.{extension}'):
42+
_tmp = [baseline[:-3]]
43+
else:
44+
_tmp = baseline.split(f'.{extension}.', 1)
45+
_tmp.insert(1, f'expr{int(EXPR._mode)}')
46+
_tmp.insert(2, version)
47+
if not os.path.exists('.'.join(_tmp)):
48+
_tmp.pop(1)
49+
if not os.path.exists('.'.join(_tmp)):
50+
_tmp = []
51+
if _tmp:
52+
baseline = '.'.join(_tmp)
53+
with open(baseline, 'r') as FILE:
54+
base = FILE.read()
55+
return base, test, baseline, testfile

pyomo/repn/tests/lp_diff.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,18 @@
99
# This software is distributed under the 3-clause BSD License.
1010
# ___________________________________________________________________________
1111

12-
import os
1312
import re
1413

1514
from difflib import SequenceMatcher, unified_diff
1615

17-
import pyomo.core.expr.current as EXPR
16+
from pyomo.repn.tests.diffutils import compare_floats, load_baseline
1817

1918
_strip_comment = re.compile(r'\s*\\.*')
2019

2120

22-
def _compare_floats(base, test, abstol=1e-14, reltol=1e-14):
23-
base = base.split()
24-
test = test.split()
25-
if len(base) != len(test):
26-
return False
27-
for i, b in enumerate(base):
28-
if b.strip() == test[i].strip():
29-
continue
30-
try:
31-
b = float(b)
32-
t = float(test[i])
33-
except:
34-
return False
35-
if abs(b - t) < abstol:
36-
continue
37-
if abs((b - t) / max(abs(b), abs(t))) < reltol:
38-
continue
39-
return False
40-
return True
41-
42-
4321
def _update_subsets(subset, base, test):
4422
for i, j in zip(*subset):
45-
if _compare_floats(base[i], test[j]):
23+
if compare_floats(base[i], test[j]):
4624
base[i] = test[j]
4725

4826

@@ -90,23 +68,7 @@ def lp_diff(base, test, baseline='baseline', testfile='testfile'):
9068

9169

9270
def load_lp_baseline(baseline, testfile, version='lp'):
93-
with open(testfile, 'r') as FILE:
94-
test = FILE.read()
95-
if baseline.endswith('.lp'):
96-
_tmp = [baseline[:-3]]
97-
else:
98-
_tmp = baseline.split('.lp.', 1)
99-
_tmp.insert(1, f'expr{int(EXPR._mode)}')
100-
_tmp.insert(2, version)
101-
if not os.path.exists('.'.join(_tmp)):
102-
_tmp.pop(1)
103-
if not os.path.exists('.'.join(_tmp)):
104-
_tmp = []
105-
if _tmp:
106-
baseline = '.'.join(_tmp)
107-
with open(baseline, 'r') as FILE:
108-
base = FILE.read()
109-
return base, test, baseline, testfile
71+
return load_baseline(baseline, testfile, 'lp', version)
11072

11173

11274
def load_and_compare_lp_baseline(baseline, testfile, version='lp'):

pyomo/repn/tests/nl_diff.py

Lines changed: 6 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
# ___________________________________________________________________________
1111

1212
import itertools
13-
import os
1413
import re
1514

1615
from difflib import SequenceMatcher, unified_diff
1716

18-
import pyomo.core.expr.current as EXPR
17+
from pyomo.repn.tests.diffutils import compare_floats, load_baseline
1918
import pyomo.repn.plugins.nl_writer as nl_writer
2019

2120
template = nl_writer.text_nl_debug_template
@@ -29,40 +28,19 @@
2928
_norm_double_negation = re.compile(r'(?m)^o16(\s*#\s*-)?\no16(\s*#\s*-)?\n')
3029

3130

32-
def _compare_floats(base, test, abstol=1e-14, reltol=1e-14):
33-
base = base.split()
34-
test = test.split()
35-
if len(base) != len(test):
36-
return False
37-
for i, b in enumerate(base):
38-
if b == test[i]:
39-
continue
40-
try:
41-
b = float(b)
42-
t = float(test[i])
43-
except:
44-
return False
45-
if abs(b - t) < abstol:
46-
continue
47-
if abs((b - t) / max(abs(b), abs(t))) < reltol:
48-
continue
49-
return False
50-
return True
51-
52-
5331
def _update_subsets(subset, base, test):
5432
for i, j in zip(*subset):
5533
# Try checking for numbers
5634
if base[i][0] == 'n' and test[j][0] == 'n':
57-
if _compare_floats(base[i][1:], test[j][1:]):
35+
if compare_floats(base[i][1:], test[j][1:]):
5836
test[j] = base[i]
59-
elif _compare_floats(base[i], test[j]):
37+
elif compare_floats(base[i], test[j]):
6038
test[j] = base[i]
6139
else:
6240
# try stripping comments, but only if it results in a match
6341
base_nc = _strip_comment.sub('', base[i])
6442
test_nc = _strip_comment.sub('', test[j])
65-
if _compare_floats(base_nc, test_nc):
43+
if compare_floats(base_nc, test_nc):
6644
if len(base_nc) > len(test_nc):
6745
test[j] = base[i]
6846
else:
@@ -107,7 +85,7 @@ def nl_diff(base, test, baseline='baseline', testfile='testfile'):
10785
base_nlines = list(x for x in enumerate(base) if x[1] and x[1][0] == 'n')
10886
if len(test_nlines) == len(base_nlines):
10987
for t_line, b_line in zip(test_nlines, base_nlines):
110-
if _compare_floats(t_line[1][1:], b_line[1][1:]):
88+
if compare_floats(t_line[1][1:], b_line[1][1:]):
11189
test[t_line[0]] = base[b_line[0]]
11290

11391
for group in SequenceMatcher(None, base, test).get_grouped_opcodes(3):
@@ -133,23 +111,7 @@ def nl_diff(base, test, baseline='baseline', testfile='testfile'):
133111

134112

135113
def load_nl_baseline(baseline, testfile, version='nl'):
136-
with open(testfile, 'r') as FILE:
137-
test = FILE.read()
138-
if baseline.endswith('.nl'):
139-
_tmp = [baseline[:-3]]
140-
else:
141-
_tmp = baseline.split('.nl.', 1)
142-
_tmp.insert(1, f'expr{int(EXPR._mode)}')
143-
_tmp.insert(2, version)
144-
if not os.path.exists('.'.join(_tmp)):
145-
_tmp.pop(1)
146-
if not os.path.exists('.'.join(_tmp)):
147-
_tmp = []
148-
if _tmp:
149-
baseline = '.'.join(_tmp)
150-
with open(baseline, 'r') as FILE:
151-
base = FILE.read()
152-
return base, test, baseline, testfile
114+
return load_baseline(baseline, testfile, 'nl', version)
153115

154116

155117
def load_and_compare_nl_baseline(baseline, testfile, version='nl'):

0 commit comments

Comments
 (0)