Skip to content

Commit 648aae7

Browse files
committed
Update type registration tests to reflect automatic numpy callback
1 parent a10d364 commit 648aae7

File tree

1 file changed

+72
-18
lines changed

1 file changed

+72
-18
lines changed

pyomo/core/tests/unit/test_numvalue.py

+72-18
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,16 @@ def __init__(self, val=0):
5050

5151
class MyBogusNumericType(MyBogusType):
5252
def __add__(self, other):
53-
return MyBogusNumericType(self.val + float(other))
53+
if other.__class__ in native_numeric_types:
54+
return MyBogusNumericType(self.val + float(other))
55+
else:
56+
return NotImplemented
57+
58+
def __le__(self, other):
59+
if other.__class__ in native_numeric_types:
60+
return self.val <= float(other)
61+
else:
62+
return NotImplemented
5463

5564
def __lt__(self, other):
5665
return self.val < float(other)
@@ -534,6 +543,8 @@ def test_unknownNumericType(self):
534543
try:
535544
val = as_numeric(ref)
536545
self.assertEqual(val().val, 42.0)
546+
self.assertIn(MyBogusNumericType, native_numeric_types)
547+
self.assertIn(MyBogusNumericType, native_types)
537548
finally:
538549
native_numeric_types.remove(MyBogusNumericType)
539550
native_types.remove(MyBogusNumericType)
@@ -562,10 +573,43 @@ def test_numpy_basic_bool_registration(self):
562573
@unittest.skipUnless(numpy_available, "This test requires NumPy")
563574
def test_automatic_numpy_registration(self):
564575
cmd = (
565-
'import pyomo; from pyomo.core.base import Var, Param; '
566-
'from pyomo.core.base.units_container import units; import numpy as np; '
567-
'print(np.float64 in pyomo.common.numeric_types.native_numeric_types); '
568-
'%s; print(np.float64 in pyomo.common.numeric_types.native_numeric_types)'
576+
'from pyomo.common.numeric_types import native_numeric_types as nnt; '
577+
'print("float64" in [_.__name__ for _ in nnt]); '
578+
'import numpy; '
579+
'print("float64" in [_.__name__ for _ in nnt])'
580+
)
581+
582+
rc = subprocess.run(
583+
[sys.executable, '-c', cmd],
584+
stdout=subprocess.PIPE,
585+
stderr=subprocess.STDOUT,
586+
text=True,
587+
)
588+
self.assertEqual((rc.returncode, rc.stdout), (0, "False\nTrue\n"))
589+
590+
cmd = (
591+
'import numpy; '
592+
'from pyomo.common.numeric_types import native_numeric_types as nnt; '
593+
'print("float64" in [_.__name__ for _ in nnt])'
594+
)
595+
596+
rc = subprocess.run(
597+
[sys.executable, '-c', cmd],
598+
stdout=subprocess.PIPE,
599+
stderr=subprocess.STDOUT,
600+
text=True,
601+
)
602+
self.assertEqual((rc.returncode, rc.stdout), (0, "True\n"))
603+
604+
def test_unknownNumericType_expr_registration(self):
605+
cmd = (
606+
'import pyomo; '
607+
'from pyomo.core.base import Var, Param; '
608+
'from pyomo.core.base.units_container import units; '
609+
'from pyomo.common.numeric_types import native_numeric_types as nnt; '
610+
f'from {__name__} import MyBogusNumericType; '
611+
'ref = MyBogusNumericType(42); '
612+
'print(MyBogusNumericType in nnt); %s; print(MyBogusNumericType in nnt); '
569613
)
570614

571615
def _tester(expr):
@@ -575,19 +619,29 @@ def _tester(expr):
575619
stderr=subprocess.STDOUT,
576620
text=True,
577621
)
578-
self.assertEqual((rc.returncode, rc.stdout), (0, "False\nTrue\n"))
579-
580-
_tester('Var() <= np.float64(5)')
581-
_tester('np.float64(5) <= Var()')
582-
_tester('np.float64(5) + Var()')
583-
_tester('Var() + np.float64(5)')
584-
_tester('v = Var(); v.construct(); v.value = np.float64(5)')
585-
_tester('p = Param(mutable=True); p.construct(); p.value = np.float64(5)')
586-
_tester('v = Var(units=units.m); v.construct(); v.value = np.float64(5)')
587-
_tester(
588-
'p = Param(mutable=True, units=units.m); p.construct(); '
589-
'p.value = np.float64(5)'
590-
)
622+
self.assertEqual(
623+
(rc.returncode, rc.stdout),
624+
(
625+
0,
626+
'''False
627+
WARNING: Dynamically registering the following numeric type:
628+
pyomo.core.tests.unit.test_numvalue.MyBogusNumericType
629+
Dynamic registration is supported for convenience, but there are known
630+
limitations to this approach. We recommend explicitly registering numeric
631+
types using RegisterNumericType() or RegisterIntegerType().
632+
True
633+
''',
634+
),
635+
)
636+
637+
_tester('Var() <= ref')
638+
_tester('ref <= Var()')
639+
_tester('ref + Var()')
640+
_tester('Var() + ref')
641+
_tester('v = Var(); v.construct(); v.value = ref')
642+
_tester('p = Param(mutable=True); p.construct(); p.value = ref')
643+
_tester('v = Var(units=units.m); v.construct(); v.value = ref')
644+
_tester('p = Param(mutable=True, units=units.m); p.construct(); p.value = ref')
591645

592646

593647
if __name__ == "__main__":

0 commit comments

Comments
 (0)