Skip to content

Commit c3ad1b7

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Add ExpressionDerivedMetric for expression-based derived metrics (#4966)
Summary: Pull Request resolved: #4966 Adds `ExpressionDerivedMetric`, a `DerivedMetric` subclass that computes metric values from mathematical expressions of other metrics (e.g. `log(a) - log(b)`). Builds on the template-method pattern in D94844067: data lookup, validation, optional `relativize_inputs` normalization (with `as_percent` support), and per-arm cross-trial relativization are handled by the base class. This subclass adds expression parsing, validation, and evaluation. Key details: - Expression parsing via sympy (`sympify` / `lambdify`), consistent with `DerivedParameter` elsewhere in Ax. - Any function in Python's `math` module is supported (log, exp, sqrt, etc.). - Validates that expressions reference only declared input metrics. - Metric names with special characters (dots, slashes, colons) are supported via automatic sanitization. - `as_percent` parameter threaded through to base class for serialization. - SEM is set to NaN (non-linear error propagation not yet supported). - Registered in JSON and SQA metric registries for serialization. Reviewed By: bletham Differential Revision: D94389119 fbshipit-source-id: c4256391f1e8fb8bc2d9abca290554c19c127e74
1 parent ef1f1de commit c3ad1b7

4 files changed

Lines changed: 625 additions & 3 deletions

File tree

ax/core/derived_metric.py

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232
metric data lives in a different trial) should override
3333
``_resolve_source_trial_indices``.
3434
35+
Concrete subclasses:
36+
37+
* ``ExpressionDerivedMetric`` (below) – computes values from a mathematical
38+
expression of other metrics (e.g. ``log(a) - log(b)``).
39+
3540
.. note:: **Transform compatibility.**
3641
Derived metrics are computed *before* any adapter transforms run.
3742
Transforms that modify metric values (e.g. ``Relativize``, ``Log``) will
@@ -44,7 +49,7 @@
4449
from __future__ import annotations
4550

4651
from logging import Logger
47-
from typing import Any
52+
from typing import Any, Callable, cast
4853

4954
import pandas as pd
5055
from ax.core.base_trial import BaseTrial
@@ -53,7 +58,12 @@
5358
from ax.exceptions.core import UserInputError
5459
from ax.utils.common.logger import get_logger
5560
from ax.utils.common.result import Err, Ok
61+
from ax.utils.common.string_utils import sanitize_name, unsanitize_name
5662
from pyre_extensions import none_throws
63+
from sympy import lambdify, sympify
64+
from sympy.core.expr import Expr
65+
from sympy.core.relational import Relational
66+
from sympy.core.symbol import Symbol
5767

5868

5969
logger: Logger = get_logger(__name__)
@@ -507,3 +517,192 @@ def summary_dict(self) -> dict[str, Any]:
507517
"relativize_inputs": self._relativize_inputs,
508518
"as_percent": self._as_percent,
509519
}
520+
521+
522+
class ExpressionDerivedMetric(DerivedMetric):
523+
"""A metric computed from a mathematical expression of other metrics.
524+
525+
The expression is parsed using sympy (consistent with other expression
526+
parsing in Ax) and compiled via ``lambdify`` for fast numeric evaluation.
527+
It may reference:
528+
529+
* Input metric names as variables
530+
* Mathematical operators: ``+``, ``-``, ``*``, ``/``, ``**``
531+
* Any function available in Python's ``math`` module (e.g. ``log``,
532+
``exp``, ``sqrt``, ``abs``, ``sin``, ``cos``, ``asin``, ``pow``, etc.)
533+
* Numeric constants
534+
535+
Attributes:
536+
expression_str: The mathematical expression string.
537+
538+
Example::
539+
540+
>>> log_ratio = ExpressionDerivedMetric(
541+
... name="log_ratio",
542+
... input_metric_names=["metric_a", "metric_b"],
543+
... expression_str="log(metric_a) - log(metric_b)",
544+
... )
545+
"""
546+
547+
def __init__(
548+
self,
549+
name: str,
550+
input_metric_names: list[str],
551+
expression_str: str,
552+
relativize_inputs: bool = False,
553+
as_percent: bool = True,
554+
lower_is_better: bool | None = None,
555+
properties: dict[str, Any] | None = None,
556+
) -> None:
557+
super().__init__(
558+
name=name,
559+
input_metric_names=input_metric_names,
560+
relativize_inputs=relativize_inputs,
561+
as_percent=as_percent,
562+
lower_is_better=lower_is_better,
563+
properties=properties,
564+
)
565+
self._expression_str = expression_str
566+
567+
# Parse & validate once; cache the compiled evaluator for reuse.
568+
# sanitize_name handles metric names with dots, slashes, etc.
569+
# (consistent with DerivedParameter's expression parsing).
570+
try:
571+
self._sympy_expr: Expr = sympify( # pyre-ignore[8]
572+
sanitize_name(self._expression_str),
573+
)
574+
except Exception as e:
575+
raise UserInputError(
576+
f"Invalid expression in ExpressionDerivedMetric "
577+
f"'{self.name}': {self._expression_str}. Error: {e}"
578+
) from e
579+
self._validate_expression()
580+
# _sympy_symbols are the sanitized names used by sympy/lambdify.
581+
# _symbols are the original (unsanitized) metric names used for
582+
# looking up values at evaluation time.
583+
# Cast free_symbols to set[Symbol] since Pyre stubs use Basic
584+
# pyre-fixme[16]: Pyre cannot infer that free_symbols contains Symbol
585+
free_syms = cast(set[Symbol], self._sympy_expr.free_symbols)
586+
sympy_symbols: list[str] = sorted(s.name for s in free_syms)
587+
self._symbols: list[str] = [unsanitize_name(s) for s in sympy_symbols]
588+
self._evaluator: Callable[..., float] = lambdify(
589+
sympy_symbols, self._sympy_expr, modules="math"
590+
)
591+
592+
@property
593+
def expression_str(self) -> str:
594+
"""The expression string defining the derivation."""
595+
return self._expression_str
596+
597+
# ------------------------------------------------------------------
598+
# Validation
599+
# ------------------------------------------------------------------
600+
601+
def _validate_expression(self) -> None:
602+
"""Validate that the parsed expression is a numeric expression with
603+
only declared input metrics as free symbols."""
604+
if isinstance(self._sympy_expr, Relational):
605+
raise UserInputError(
606+
"Comparison operators are not allowed in "
607+
"ExpressionDerivedMetric expressions. "
608+
"Use outcome constraints for comparisons."
609+
)
610+
611+
# Reject undeclared variable names.
612+
# Cast free_symbols to set[Symbol] since Pyre stubs use Basic
613+
# pyre-fixme[16]: Pyre cannot infer that free_symbols contains Symbol
614+
free_syms = cast(set[Symbol], self._sympy_expr.free_symbols)
615+
referenced_names = {unsanitize_name(s.name) for s in free_syms}
616+
input_metric_set = set(self._input_metric_names)
617+
618+
unknown_names = referenced_names - input_metric_set
619+
if unknown_names:
620+
raise UserInputError(
621+
f"Expression for ExpressionDerivedMetric '{self.name}' references "
622+
f"unknown names: {unknown_names}. Allowed metric names: "
623+
f"{input_metric_set}."
624+
)
625+
626+
unused_inputs = input_metric_set - referenced_names
627+
if unused_inputs:
628+
logger.warning(
629+
f"ExpressionDerivedMetric '{self.name}' declares input metrics "
630+
f"that are not used in the expression: {unused_inputs}."
631+
)
632+
633+
# ------------------------------------------------------------------
634+
# Evaluation
635+
# ------------------------------------------------------------------
636+
637+
def _evaluate_expression(self, metric_values: dict[str, float]) -> float:
638+
"""Evaluate the expression with the given metric values."""
639+
args = [metric_values[s] for s in self._symbols]
640+
return float(self._evaluator(*args))
641+
642+
# ------------------------------------------------------------------
643+
# Core computation (subclass hook)
644+
# ------------------------------------------------------------------
645+
646+
def _compute_derived_values(
647+
self,
648+
trial: BaseTrial,
649+
arm_data: dict[str, pd.DataFrame],
650+
) -> MetricFetchResult:
651+
"""Evaluate the expression for each arm using pre-collected data.
652+
653+
When ``relativize_inputs`` is ``True``, the base class has already
654+
relativized the ``mean`` values and excluded the status quo arm.
655+
"""
656+
result_rows: list[dict[str, Any]] = []
657+
658+
for arm_name, arm_df in arm_data.items():
659+
try:
660+
metric_values = self._extract_means(arm_df)
661+
derived_value = self._evaluate_expression(metric_values)
662+
except Exception as e:
663+
return Err(
664+
MetricFetchE(
665+
message=(
666+
f"Error evaluating ExpressionDerivedMetric "
667+
f"'{self.name}' for arm '{arm_name}' "
668+
f"in trial {trial.index}: {e}"
669+
),
670+
exception=e,
671+
)
672+
)
673+
674+
result_rows.append(
675+
{
676+
"trial_index": trial.index,
677+
"arm_name": arm_name,
678+
"metric_name": self.name,
679+
"metric_signature": self.signature,
680+
"mean": derived_value,
681+
"sem": float("nan"),
682+
}
683+
)
684+
685+
return Ok(value=Data(df=pd.DataFrame(result_rows)))
686+
687+
# ------------------------------------------------------------------
688+
# Misc
689+
# ------------------------------------------------------------------
690+
691+
def __repr__(self) -> str:
692+
parts = [
693+
f"name='{self.name}'",
694+
f"expression='{self._expression_str}'",
695+
]
696+
if self._relativize_inputs:
697+
parts.append("relativize_inputs=True")
698+
return f"ExpressionDerivedMetric({', '.join(parts)})"
699+
700+
@property
701+
def summary_dict(self) -> dict[str, Any]:
702+
"""Fields of this metric's configuration that will appear
703+
in the ``Summary`` analysis table.
704+
"""
705+
return {
706+
**super().summary_dict,
707+
"expression_str": self._expression_str,
708+
}

0 commit comments

Comments
 (0)