|
19 | 19 | # pylint: disable=missing-function-docstring |
20 | 20 | # pylint: disable=protected-access |
21 | 21 |
|
22 | | -__author__ = "Andrew Lee" |
23 | | - |
| 22 | +__author__ = "Andrew Lee, Douglas Allan" |
24 | 23 |
|
| 24 | +from math import ceil |
25 | 25 | import os |
| 26 | +import textwrap |
26 | 27 | from typing import Callable, Union |
27 | 28 |
|
28 | | -from pyomo.environ import Constraint, Set, units, Var |
| 29 | +import pytest |
| 30 | + |
| 31 | +from pyomo.environ import Constraint, Expression, log10, Set, units, Var, value |
29 | 32 | from pyomo.common.config import ConfigBlock |
30 | 33 | from pyomo.common import Executable |
31 | 34 | from pyomo.common.dependencies import attempt_import |
@@ -513,3 +516,124 @@ def remove_from_path(): |
513 | 516 | Executable.rehash() |
514 | 517 |
|
515 | 518 | return remove_from_path |
| 519 | + |
| 520 | + |
| 521 | +def assert_solution_equivalent(blk, expected_results): |
| 522 | + """ |
| 523 | + Method to iterate through a structured dictionary of variables/expressions, values, |
| 524 | + and indices to determine whether the variables have their expected values within the |
| 525 | + specified tolerances. The variables that do not have their expected values are |
| 526 | + collected and displayed for the user, and an AssertionError is raised. |
| 527 | +
|
| 528 | + This method is better than just writing a long series of assert statements because |
| 529 | + it shows *all* variables/expressions that do not have the expected values in a single |
| 530 | + report, rather than having to change the value in one assert statement, run the test |
| 531 | + again, change the next value, run the test again, etc. |
| 532 | +
|
| 533 | + This function was partially generated by AI. |
| 534 | +
|
| 535 | + Args: |
| 536 | + blk: Pyomo block on which variables/expressions being tested are located |
| 537 | + expected_results: Dictionary of the form: |
| 538 | + { |
| 539 | + indexed_var_name: { |
| 540 | + index_1: (value, rel_tol, abs_tol), |
| 541 | + index_2: (value, rel_tol, abs_tol), |
| 542 | + ... |
| 543 | + } |
| 544 | + unindexed_var_name: { |
| 545 | + # Unindexed vars pass None as the index |
| 546 | + None: (value, rel_tol, abs_tol) |
| 547 | + } |
| 548 | + ... |
| 549 | + } |
| 550 | + """ |
| 551 | + |
| 552 | + n_failures = 0 |
| 553 | + failures = [] |
| 554 | + |
| 555 | + for name, expected_values_dict in expected_results.items(): |
| 556 | + recorded_var = False |
| 557 | + obj = blk.find_component(name) |
| 558 | + if obj is None: |
| 559 | + blk_name = blk.name |
| 560 | + # Pyomo ConcreteModels are named "unknown" by default |
| 561 | + # but seeing "unknown" show up in an error message is confusing. |
| 562 | + if blk_name == "unknown": |
| 563 | + blk_name = "model" |
| 564 | + failure_msg = f" - Could not find object {name} on {blk_name}\n" |
| 565 | + failures.append(failure_msg) |
| 566 | + continue |
| 567 | + |
| 568 | + obj_type = None |
| 569 | + if isinstance(obj, Var): |
| 570 | + obj_type = "Variable" |
| 571 | + elif isinstance(obj, Expression): |
| 572 | + obj_type = "Expression" |
| 573 | + else: |
| 574 | + failure_msg = f" - Error: object {name} is not a Var or Expression\n" |
| 575 | + failures.append(failure_msg) |
| 576 | + continue |
| 577 | + |
| 578 | + for index, (expected_value, rel, abs) in expected_values_dict.items(): |
| 579 | + absent_index = False |
| 580 | + is_close = False |
| 581 | + if index is None: |
| 582 | + component_to_test = obj |
| 583 | + else: |
| 584 | + if index in obj: |
| 585 | + component_to_test = obj[index] |
| 586 | + else: |
| 587 | + absent_index = True |
| 588 | + if not absent_index: |
| 589 | + actual_value = value(component_to_test) |
| 590 | + |
| 591 | + # Determine if the values are approximately equal |
| 592 | + if actual_value == pytest.approx(expected_value, rel=rel, abs=abs): |
| 593 | + is_close = True |
| 594 | + if (absent_index or not is_close) and not recorded_var: |
| 595 | + failures.append(f" - {obj_type}: {name}") |
| 596 | + recorded_var = True |
| 597 | + if absent_index: |
| 598 | + failure_msg = f" Index: {index} is absent" |
| 599 | + failures.append(failure_msg) |
| 600 | + n_failures += 1 |
| 601 | + continue |
| 602 | + |
| 603 | + # If the comparison fails, record the details |
| 604 | + if not is_close: |
| 605 | + if rel is not None: |
| 606 | + n_sig_figs = ceil(-log10(rel)) + 1 |
| 607 | + format_spec = "." + str(n_sig_figs) + "e" |
| 608 | + elif abs is not None: |
| 609 | + n_sig_figs = ceil(-log10(abs)) + 1 |
| 610 | + format_spec = "." + str(n_sig_figs) + "f" |
| 611 | + else: |
| 612 | + format_spec = ".7e" |
| 613 | + failure_msg = ( |
| 614 | + f" Index: {index}\n" |
| 615 | + f" Expected: {expected_value:{format_spec}}\n" |
| 616 | + f" Actual: {actual_value:{format_spec}}" |
| 617 | + ) |
| 618 | + failures.append(failure_msg) |
| 619 | + n_failures += 1 |
| 620 | + |
| 621 | + if recorded_var: |
| 622 | + # Extra space between variables |
| 623 | + failures[-1] = failures[-1] + "\n" |
| 624 | + |
| 625 | + # --- Final Assertion and Report Generation --- |
| 626 | + if len(failures) > 0: |
| 627 | + # Construct the final report header |
| 628 | + report_header = textwrap.dedent(f""" |
| 629 | + =========================== Test Value Mismatches ============================ |
| 630 | + Found {n_failures} mismatch(es) between expected and actual model values. |
| 631 | + Please review the values below and update the test suite if necessary. |
| 632 | + ============================================================================== |
| 633 | + """) |
| 634 | + |
| 635 | + # Combine the header and all failure messages |
| 636 | + full_report = report_header + "\n\n" + "\n\n".join(failures) |
| 637 | + |
| 638 | + # Raise a single AssertionError with the complete report |
| 639 | + raise AssertionError(full_report) |
0 commit comments