Skip to content

Commit 2f859b1

Browse files
assert_solution_equivalent
1 parent f52e707 commit 2f859b1

File tree

2 files changed

+714
-3
lines changed

2 files changed

+714
-3
lines changed

idaes/core/util/testing.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@
1919
# pylint: disable=missing-function-docstring
2020
# pylint: disable=protected-access
2121

22-
__author__ = "Andrew Lee"
23-
22+
__author__ = "Andrew Lee, Douglas Allan"
2423

24+
from math import ceil
2525
import os
26+
import textwrap
2627
from typing import Callable, Union
2728

28-
from pyomo.environ import Constraint, Set, units, Var
29+
import pytest
30+
31+
from pyomo.environ import Constraint, Expression, log10, Set, units, Var, value
2932
from pyomo.common.config import ConfigBlock
3033
from pyomo.common import Executable
3134
from pyomo.common.dependencies import attempt_import
@@ -513,3 +516,124 @@ def remove_from_path():
513516
Executable.rehash()
514517

515518
return remove_from_path
519+
520+
521+
def assert_solution_equivalent(blk, expected_results):
522+
"""
523+
Method to iterate through a structured dictionary of variables/expressions, values,
524+
and indices to determine whether the variables have their expected values within the
525+
specified tolerances. The variables that do not have their expected values are
526+
collected and displayed for the user, and an AssertionError is raised.
527+
528+
This method is better than just writing a long series of assert statements because
529+
it shows *all* variables/expressions that do not have the expected values in a single
530+
report, rather than having to change the value in one assert statement, run the test
531+
again, change the next value, run the test again, etc.
532+
533+
This function was partially generated by AI.
534+
535+
Args:
536+
blk: Pyomo block on which variables/expressions being tested are located
537+
expected_results: Dictionary of the form:
538+
{
539+
indexed_var_name: {
540+
index_1: (value, rel_tol, abs_tol),
541+
index_2: (value, rel_tol, abs_tol),
542+
...
543+
}
544+
unindexed_var_name: {
545+
# Unindexed vars pass None as the index
546+
None: (value, rel_tol, abs_tol)
547+
}
548+
...
549+
}
550+
"""
551+
552+
n_failures = 0
553+
failures = []
554+
555+
for name, expected_values_dict in expected_results.items():
556+
recorded_var = False
557+
obj = blk.find_component(name)
558+
if obj is None:
559+
blk_name = blk.name
560+
# Pyomo ConcreteModels are named "unknown" by default
561+
# but seeing "unknown" show up in an error message is confusing.
562+
if blk_name == "unknown":
563+
blk_name = "model"
564+
failure_msg = f" - Could not find object {name} on {blk_name}\n"
565+
failures.append(failure_msg)
566+
continue
567+
568+
obj_type = None
569+
if isinstance(obj, Var):
570+
obj_type = "Variable"
571+
elif isinstance(obj, Expression):
572+
obj_type = "Expression"
573+
else:
574+
failure_msg = f" - Error: object {name} is not a Var or Expression\n"
575+
failures.append(failure_msg)
576+
continue
577+
578+
for index, (expected_value, rel, abs) in expected_values_dict.items():
579+
absent_index = False
580+
is_close = False
581+
if index is None:
582+
component_to_test = obj
583+
else:
584+
if index in obj:
585+
component_to_test = obj[index]
586+
else:
587+
absent_index = True
588+
if not absent_index:
589+
actual_value = value(component_to_test)
590+
591+
# Determine if the values are approximately equal
592+
if actual_value == pytest.approx(expected_value, rel=rel, abs=abs):
593+
is_close = True
594+
if (absent_index or not is_close) and not recorded_var:
595+
failures.append(f" - {obj_type}: {name}")
596+
recorded_var = True
597+
if absent_index:
598+
failure_msg = f" Index: {index} is absent"
599+
failures.append(failure_msg)
600+
n_failures += 1
601+
continue
602+
603+
# If the comparison fails, record the details
604+
if not is_close:
605+
if rel is not None:
606+
n_sig_figs = ceil(-log10(rel)) + 1
607+
format_spec = "." + str(n_sig_figs) + "e"
608+
elif abs is not None:
609+
n_sig_figs = ceil(-log10(abs)) + 1
610+
format_spec = "." + str(n_sig_figs) + "f"
611+
else:
612+
format_spec = ".7e"
613+
failure_msg = (
614+
f" Index: {index}\n"
615+
f" Expected: {expected_value:{format_spec}}\n"
616+
f" Actual: {actual_value:{format_spec}}"
617+
)
618+
failures.append(failure_msg)
619+
n_failures += 1
620+
621+
if recorded_var:
622+
# Extra space between variables
623+
failures[-1] = failures[-1] + "\n"
624+
625+
# --- Final Assertion and Report Generation ---
626+
if len(failures) > 0:
627+
# Construct the final report header
628+
report_header = textwrap.dedent(f"""
629+
=========================== Test Value Mismatches ============================
630+
Found {n_failures} mismatch(es) between expected and actual model values.
631+
Please review the values below and update the test suite if necessary.
632+
==============================================================================
633+
""")
634+
635+
# Combine the header and all failure messages
636+
full_report = report_header + "\n\n" + "\n\n".join(failures)
637+
638+
# Raise a single AssertionError with the complete report
639+
raise AssertionError(full_report)

0 commit comments

Comments
 (0)