Skip to content

Commit 3f21a2d

Browse files
authored
Merge pull request #3423 from jsiirola/numpy-scalar
Resolve errors in mapping ScalarVar to numpy ndarray
2 parents 5669dfb + 05bdaf7 commit 3f21a2d

File tree

3 files changed

+123
-18
lines changed

3 files changed

+123
-18
lines changed

pyomo/core/base/indexed_component.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1201,7 +1201,7 @@ def __array__(self, dtype=None):
12011201
if not self.is_indexed():
12021202
ans = _ndarray.NumericNDArray(shape=(1,), dtype=object)
12031203
ans[0] = self
1204-
return ans
1204+
return ans.reshape(())
12051205

12061206
_dim = self.dim()
12071207
if _dim is None:

pyomo/core/expr/compare.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def handle_external_function_expression(node: ExternalFunctionExpression, pn: Li
6666
return node.args
6767

6868

69+
def handle_sequence(node: collections.abc.Sequence, pn: List):
70+
pn.append((collections.abc.Sequence, len(node)))
71+
return list(node)
72+
73+
6974
def _generic_expression_handler():
7075
return handle_expression
7176

@@ -79,6 +84,7 @@ def _generic_expression_handler():
7984
handler[AbsExpression] = handle_unary_expression
8085
handler[NPV_AbsExpression] = handle_unary_expression
8186
handler[RangedExpression] = handle_expression
87+
handler[list] = handle_sequence
8288

8389

8490
class PrefixVisitor(StreamBasedExpressionVisitor):
@@ -97,19 +103,26 @@ def enterNode(self, node):
97103
self._result.append(node)
98104
return tuple(), None
99105

100-
if node.is_expression_type():
101-
if node.is_named_expression_type():
102-
return (
103-
handle_named_expression(
104-
node, self._result, self._include_named_exprs
105-
),
106-
None,
107-
)
108-
else:
109-
return handler[ntype](node, self._result), None
110-
else:
111-
self._result.append(node)
112-
return tuple(), None
106+
if ntype in handler:
107+
return handler[ntype](node, self._result), None
108+
109+
if hasattr(node, 'is_expression_type'):
110+
if node.is_expression_type():
111+
if node.is_named_expression_type():
112+
return (
113+
handle_named_expression(
114+
node, self._result, self._include_named_exprs
115+
),
116+
None,
117+
)
118+
else:
119+
return handler[ntype](node, self._result), None
120+
elif hasattr(node, '__len__'):
121+
handler[ntype] = handle_sequence
122+
return handle_sequence(node, self._result), None
123+
124+
self._result.append(node)
125+
return tuple(), None
113126

114127
def finalizeResult(self, result):
115128
ans = self._result
@@ -161,10 +174,7 @@ def convert_expression_to_prefix_notation(expr, include_named_exprs=True):
161174
162175
"""
163176
visitor = PrefixVisitor(include_named_exprs=include_named_exprs)
164-
if isinstance(expr, Sequence):
165-
return expr.__class__(visitor.walk_expression(e) for e in expr)
166-
else:
167-
return visitor.walk_expression(expr)
177+
return visitor.walk_expression(expr)
168178

169179

170180
def compare_expressions(expr1, expr2, include_named_exprs=True):
+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# ___________________________________________________________________________
2+
#
3+
# Pyomo: Python Optimization Modeling Objects
4+
# Copyright (c) 2008-2024
5+
# National Technology and Engineering Solutions of Sandia, LLC
6+
# Under the terms of Contract DE-NA0003525 with National Technology and
7+
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
8+
# rights in this software.
9+
# This software is distributed under the 3-clause BSD License.
10+
# ___________________________________________________________________________
11+
12+
import pyomo.common.unittest as unittest
13+
14+
from pyomo.common.dependencies import numpy as np, numpy_available
15+
from pyomo.environ import ConcreteModel, Var, Constraint
16+
17+
18+
@unittest.skipUnless(numpy_available, "tests require numpy")
19+
class TestNumpyExpr(unittest.TestCase):
20+
def test_scalar_operations(self):
21+
m = ConcreteModel()
22+
m.x = Var()
23+
24+
a = np.array(m.x)
25+
self.assertEqual(a.shape, ())
26+
27+
self.assertExpressionsEqual(5 * a, 5 * m.x)
28+
self.assertExpressionsEqual(np.array([2, 3]) * a, [2 * m.x, 3 * m.x])
29+
self.assertExpressionsEqual(np.array([5, 6]) * m.x, [5 * m.x, 6 * m.x])
30+
self.assertExpressionsEqual(np.array([8, m.x]) * m.x, [8 * m.x, m.x * m.x])
31+
32+
a = np.array([m.x])
33+
self.assertEqual(a.shape, (1,))
34+
35+
self.assertExpressionsEqual(5 * a, [5 * m.x])
36+
self.assertExpressionsEqual(np.array([2, 3]) * a, [2 * m.x, 3 * m.x])
37+
self.assertExpressionsEqual(np.array([5, 6]) * m.x, [5 * m.x, 6 * m.x])
38+
self.assertExpressionsEqual(np.array([8, m.x]) * m.x, [8 * m.x, m.x * m.x])
39+
40+
def test_vector_operations(self):
41+
m = ConcreteModel()
42+
m.x = Var()
43+
m.y = Var([0, 1, 2])
44+
45+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
46+
# TODO: when we finally support a true matrix expression
47+
# system, this test should work
48+
self.assertExpressionsEqual(5 * m.y, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
49+
50+
a = np.array(5)
51+
self.assertExpressionsEqual(a * m.y, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
52+
self.assertExpressionsEqual(m.y * a, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
53+
a = np.array([5])
54+
self.assertExpressionsEqual(a * m.y, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
55+
self.assertExpressionsEqual(m.y * a, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
56+
57+
a = np.array(5)
58+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
59+
# TODO: when we finally support a true matrix expression
60+
# system, this test should work
61+
self.assertExpressionsEqual(
62+
a * m.x * m.y, [5 * m.x * m.y[0], 5 * m.x * m.y[1], 5 * m.x * m.y[2]]
63+
)
64+
self.assertExpressionsEqual(
65+
a * m.y * m.x, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
66+
)
67+
self.assertExpressionsEqual(
68+
a * m.y * m.y,
69+
[5 * m.y[0] * m.y[0], 5 * m.y[1] * m.y[1], 5 * m.y[2] * m.y[2]],
70+
)
71+
self.assertExpressionsEqual(
72+
m.y * a * m.x, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
73+
)
74+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
75+
# TODO: when we finally support a true matrix expression
76+
# system, this test should work
77+
self.assertExpressionsEqual(
78+
m.y * m.x * a, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
79+
)
80+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
81+
# TODO: when we finally support a true matrix expression
82+
# system, this test should work
83+
self.assertExpressionsEqual(
84+
m.x * a * m.y, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
85+
)
86+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
87+
# TODO: when we finally support a true matrix expression
88+
# system, this test should work
89+
self.assertExpressionsEqual(
90+
m.x * m.y * a, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
91+
)
92+
93+
94+
if __name__ == "__main__":
95+
unittest.main()

0 commit comments

Comments
 (0)