|
32 | 32 | metric data lives in a different trial) should override |
33 | 33 | ``_resolve_source_trial_indices``. |
34 | 34 |
|
| 35 | +Concrete subclasses: |
| 36 | +
|
| 37 | +* ``ExpressionDerivedMetric`` (below) – computes values from a mathematical |
| 38 | + expression of other metrics (e.g. ``log(a) - log(b)``). |
| 39 | +
|
35 | 40 | .. note:: **Transform compatibility.** |
36 | 41 | Derived metrics are computed *before* any adapter transforms run. |
37 | 42 | Transforms that modify metric values (e.g. ``Relativize``, ``Log``) will |
|
44 | 49 | from __future__ import annotations |
45 | 50 |
|
46 | 51 | from logging import Logger |
47 | | -from typing import Any |
| 52 | +from typing import Any, Callable, cast |
48 | 53 |
|
49 | 54 | import pandas as pd |
50 | 55 | from ax.core.base_trial import BaseTrial |
|
53 | 58 | from ax.exceptions.core import UserInputError |
54 | 59 | from ax.utils.common.logger import get_logger |
55 | 60 | from ax.utils.common.result import Err, Ok |
| 61 | +from ax.utils.common.string_utils import sanitize_name, unsanitize_name |
56 | 62 | from pyre_extensions import none_throws |
| 63 | +from sympy import lambdify, sympify |
| 64 | +from sympy.core.expr import Expr |
| 65 | +from sympy.core.relational import Relational |
| 66 | +from sympy.core.symbol import Symbol |
57 | 67 |
|
58 | 68 |
|
59 | 69 | logger: Logger = get_logger(__name__) |
@@ -507,3 +517,192 @@ def summary_dict(self) -> dict[str, Any]: |
507 | 517 | "relativize_inputs": self._relativize_inputs, |
508 | 518 | "as_percent": self._as_percent, |
509 | 519 | } |
| 520 | + |
| 521 | + |
| 522 | +class ExpressionDerivedMetric(DerivedMetric): |
| 523 | + """A metric computed from a mathematical expression of other metrics. |
| 524 | +
|
| 525 | + The expression is parsed using sympy (consistent with other expression |
| 526 | + parsing in Ax) and compiled via ``lambdify`` for fast numeric evaluation. |
| 527 | + It may reference: |
| 528 | +
|
| 529 | + * Input metric names as variables |
| 530 | + * Mathematical operators: ``+``, ``-``, ``*``, ``/``, ``**`` |
| 531 | + * Any function available in Python's ``math`` module (e.g. ``log``, |
| 532 | + ``exp``, ``sqrt``, ``abs``, ``sin``, ``cos``, ``asin``, ``pow``, etc.) |
| 533 | + * Numeric constants |
| 534 | +
|
| 535 | + Attributes: |
| 536 | + expression_str: The mathematical expression string. |
| 537 | +
|
| 538 | + Example:: |
| 539 | +
|
| 540 | + >>> log_ratio = ExpressionDerivedMetric( |
| 541 | + ... name="log_ratio", |
| 542 | + ... input_metric_names=["metric_a", "metric_b"], |
| 543 | + ... expression_str="log(metric_a) - log(metric_b)", |
| 544 | + ... ) |
| 545 | + """ |
| 546 | + |
| 547 | + def __init__( |
| 548 | + self, |
| 549 | + name: str, |
| 550 | + input_metric_names: list[str], |
| 551 | + expression_str: str, |
| 552 | + relativize_inputs: bool = False, |
| 553 | + as_percent: bool = True, |
| 554 | + lower_is_better: bool | None = None, |
| 555 | + properties: dict[str, Any] | None = None, |
| 556 | + ) -> None: |
| 557 | + super().__init__( |
| 558 | + name=name, |
| 559 | + input_metric_names=input_metric_names, |
| 560 | + relativize_inputs=relativize_inputs, |
| 561 | + as_percent=as_percent, |
| 562 | + lower_is_better=lower_is_better, |
| 563 | + properties=properties, |
| 564 | + ) |
| 565 | + self._expression_str = expression_str |
| 566 | + |
| 567 | + # Parse & validate once; cache the compiled evaluator for reuse. |
| 568 | + # sanitize_name handles metric names with dots, slashes, etc. |
| 569 | + # (consistent with DerivedParameter's expression parsing). |
| 570 | + try: |
| 571 | + self._sympy_expr: Expr = sympify( # pyre-ignore[8] |
| 572 | + sanitize_name(self._expression_str), |
| 573 | + ) |
| 574 | + except Exception as e: |
| 575 | + raise UserInputError( |
| 576 | + f"Invalid expression in ExpressionDerivedMetric " |
| 577 | + f"'{self.name}': {self._expression_str}. Error: {e}" |
| 578 | + ) from e |
| 579 | + self._validate_expression() |
| 580 | + # _sympy_symbols are the sanitized names used by sympy/lambdify. |
| 581 | + # _symbols are the original (unsanitized) metric names used for |
| 582 | + # looking up values at evaluation time. |
| 583 | + # Cast free_symbols to set[Symbol] since Pyre stubs use Basic |
| 584 | + # pyre-fixme[16]: Pyre cannot infer that free_symbols contains Symbol |
| 585 | + free_syms = cast(set[Symbol], self._sympy_expr.free_symbols) |
| 586 | + sympy_symbols: list[str] = sorted(s.name for s in free_syms) |
| 587 | + self._symbols: list[str] = [unsanitize_name(s) for s in sympy_symbols] |
| 588 | + self._evaluator: Callable[..., float] = lambdify( |
| 589 | + sympy_symbols, self._sympy_expr, modules="math" |
| 590 | + ) |
| 591 | + |
| 592 | + @property |
| 593 | + def expression_str(self) -> str: |
| 594 | + """The expression string defining the derivation.""" |
| 595 | + return self._expression_str |
| 596 | + |
| 597 | + # ------------------------------------------------------------------ |
| 598 | + # Validation |
| 599 | + # ------------------------------------------------------------------ |
| 600 | + |
| 601 | + def _validate_expression(self) -> None: |
| 602 | + """Validate that the parsed expression is a numeric expression with |
| 603 | + only declared input metrics as free symbols.""" |
| 604 | + if isinstance(self._sympy_expr, Relational): |
| 605 | + raise UserInputError( |
| 606 | + "Comparison operators are not allowed in " |
| 607 | + "ExpressionDerivedMetric expressions. " |
| 608 | + "Use outcome constraints for comparisons." |
| 609 | + ) |
| 610 | + |
| 611 | + # Reject undeclared variable names. |
| 612 | + # Cast free_symbols to set[Symbol] since Pyre stubs use Basic |
| 613 | + # pyre-fixme[16]: Pyre cannot infer that free_symbols contains Symbol |
| 614 | + free_syms = cast(set[Symbol], self._sympy_expr.free_symbols) |
| 615 | + referenced_names = {unsanitize_name(s.name) for s in free_syms} |
| 616 | + input_metric_set = set(self._input_metric_names) |
| 617 | + |
| 618 | + unknown_names = referenced_names - input_metric_set |
| 619 | + if unknown_names: |
| 620 | + raise UserInputError( |
| 621 | + f"Expression for ExpressionDerivedMetric '{self.name}' references " |
| 622 | + f"unknown names: {unknown_names}. Allowed metric names: " |
| 623 | + f"{input_metric_set}." |
| 624 | + ) |
| 625 | + |
| 626 | + unused_inputs = input_metric_set - referenced_names |
| 627 | + if unused_inputs: |
| 628 | + logger.warning( |
| 629 | + f"ExpressionDerivedMetric '{self.name}' declares input metrics " |
| 630 | + f"that are not used in the expression: {unused_inputs}." |
| 631 | + ) |
| 632 | + |
| 633 | + # ------------------------------------------------------------------ |
| 634 | + # Evaluation |
| 635 | + # ------------------------------------------------------------------ |
| 636 | + |
| 637 | + def _evaluate_expression(self, metric_values: dict[str, float]) -> float: |
| 638 | + """Evaluate the expression with the given metric values.""" |
| 639 | + args = [metric_values[s] for s in self._symbols] |
| 640 | + return float(self._evaluator(*args)) |
| 641 | + |
| 642 | + # ------------------------------------------------------------------ |
| 643 | + # Core computation (subclass hook) |
| 644 | + # ------------------------------------------------------------------ |
| 645 | + |
| 646 | + def _compute_derived_values( |
| 647 | + self, |
| 648 | + trial: BaseTrial, |
| 649 | + arm_data: dict[str, pd.DataFrame], |
| 650 | + ) -> MetricFetchResult: |
| 651 | + """Evaluate the expression for each arm using pre-collected data. |
| 652 | +
|
| 653 | + When ``relativize_inputs`` is ``True``, the base class has already |
| 654 | + relativized the ``mean`` values and excluded the status quo arm. |
| 655 | + """ |
| 656 | + result_rows: list[dict[str, Any]] = [] |
| 657 | + |
| 658 | + for arm_name, arm_df in arm_data.items(): |
| 659 | + try: |
| 660 | + metric_values = self._extract_means(arm_df) |
| 661 | + derived_value = self._evaluate_expression(metric_values) |
| 662 | + except Exception as e: |
| 663 | + return Err( |
| 664 | + MetricFetchE( |
| 665 | + message=( |
| 666 | + f"Error evaluating ExpressionDerivedMetric " |
| 667 | + f"'{self.name}' for arm '{arm_name}' " |
| 668 | + f"in trial {trial.index}: {e}" |
| 669 | + ), |
| 670 | + exception=e, |
| 671 | + ) |
| 672 | + ) |
| 673 | + |
| 674 | + result_rows.append( |
| 675 | + { |
| 676 | + "trial_index": trial.index, |
| 677 | + "arm_name": arm_name, |
| 678 | + "metric_name": self.name, |
| 679 | + "metric_signature": self.signature, |
| 680 | + "mean": derived_value, |
| 681 | + "sem": float("nan"), |
| 682 | + } |
| 683 | + ) |
| 684 | + |
| 685 | + return Ok(value=Data(df=pd.DataFrame(result_rows))) |
| 686 | + |
| 687 | + # ------------------------------------------------------------------ |
| 688 | + # Misc |
| 689 | + # ------------------------------------------------------------------ |
| 690 | + |
| 691 | + def __repr__(self) -> str: |
| 692 | + parts = [ |
| 693 | + f"name='{self.name}'", |
| 694 | + f"expression='{self._expression_str}'", |
| 695 | + ] |
| 696 | + if self._relativize_inputs: |
| 697 | + parts.append("relativize_inputs=True") |
| 698 | + return f"ExpressionDerivedMetric({', '.join(parts)})" |
| 699 | + |
| 700 | + @property |
| 701 | + def summary_dict(self) -> dict[str, Any]: |
| 702 | + """Fields of this metric's configuration that will appear |
| 703 | + in the ``Summary`` analysis table. |
| 704 | + """ |
| 705 | + return { |
| 706 | + **super().summary_dict, |
| 707 | + "expression_str": self._expression_str, |
| 708 | + } |
0 commit comments