Skip to content

Commit e24e156

Browse files
ItsMrLin and meta-codesync[bot]
authored and committed
Add DerivedMetric base class for metrics computed from other metrics (facebook#4950)
Summary: Pull Request resolved: facebook#4950 Introduces `DerivedMetric`, a base `Metric` subclass for metrics whose values depend on other metrics being fetched first. This enables a two-phase data-fetch pattern where base metrics are fetched and cached before derived metrics are computed. Key changes: 1. **`DerivedMetric` base class** (`ax/core/derived_metric.py`): Declares `input_metric_names` — names of metrics that must be available before this metric can be computed. Subclasses override `fetch_trial_data` to define the derivation logic. 2. **Two-phase fetching in `Experiment`**: `_lookup_or_fetch_trials_results` now separates base metrics from derived metrics, fetches base metrics first and attaches them to the cache, then fetches derived metrics. This ensures derived metrics can read their inputs via `trial.lookup_data()`. 3. **Test coverage**: Unit tests for the base class (init, validation, clone, summary_dict) and integration tests verifying the two-phase fetch in `Experiment`. A concrete subclass (`ExpressionDerivedMetric`) for expression-based derivation and storage registration is added in a follow-up diff. Reviewed By: lena-kashtelyan, saitcakmak Differential Revision: D92749156 fbshipit-source-id: 2e845bcbcc11798eea4de4f304e3713781ec8ba0
1 parent e6961dc commit e24e156

File tree

3 files changed

+424
-14
lines changed

3 files changed

+424
-14
lines changed

ax/core/derived_metric.py

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
DerivedMetric: A metric computed from other metrics.
11+
12+
``DerivedMetric`` is a base class for metrics whose values depend on other
13+
metrics being fetched first. The experiment's data-fetch loop uses
14+
``isinstance(m, DerivedMetric)`` to guarantee that all base metric data is
15+
attached to the cache before any derived metric's ``fetch_trial_data`` runs.
16+
17+
.. note:: **Transform compatibility.**
18+
Derived metrics are computed *before* any adapter transforms run.
19+
Transforms that modify metric values (e.g. ``Relativize``, ``Log``) will
20+
be applied to the already-computed derived value, **not** to its inputs
21+
individually. This means a derived metric ``log(a) - log(b)`` followed
22+
by a ``Log`` transform would double-log the result. Avoid using
23+
transforms that overlap with operations already baked into the derivation.
24+
"""
25+
26+
from __future__ import annotations
27+
28+
from logging import Logger
29+
from typing import Any
30+
31+
import pandas as pd
32+
from ax.core.metric import Metric
33+
from ax.exceptions.core import UserInputError
34+
from ax.utils.common.logger import get_logger
35+
36+
37+
logger: Logger = get_logger(__name__)
38+
39+
40+
class DerivedMetric(Metric):
    """Base class for metrics that depend on other metrics.

    A ``DerivedMetric`` declares the names of metrics whose data must be
    available before this metric can be computed. The experiment's two-phase
    fetch loop (see ``Experiment._lookup_or_fetch_trials_results``) separates
    derived metrics from base metrics and fetches base metrics first.

    Subclasses must override ``fetch_trial_data`` to define how the derived
    value is produced.

    Attributes:
        input_metric_names: Names of metrics that must be fetched first.
    """

    def __init__(
        self,
        name: str,
        input_metric_names: list[str],
        lower_is_better: bool | None = None,
        properties: dict[str, Any] | None = None,
    ) -> None:
        """Initialize a derived metric.

        Args:
            name: Name of this metric.
            input_metric_names: Names of the metrics whose data must be
                fetched before this metric can be computed. Must be non-empty.
            lower_is_better: Whether lower values of this metric are better.
            properties: Optional dictionary of this metric's properties.

        Raises:
            UserInputError: If ``input_metric_names`` is empty.
        """
        if not input_metric_names:
            raise UserInputError(
                f"DerivedMetric '{name}' must declare at least one input "
                f"metric in input_metric_names."
            )
        super().__init__(
            name=name,
            lower_is_better=lower_is_better,
            properties=properties,
        )
        # Defensive copy: storing the caller's list by reference would let
        # later external mutation silently change this metric's declared
        # dependencies (and the contents of `summary_dict`).
        self._input_metric_names: list[str] = list(input_metric_names)

    @property
    def input_metric_names(self) -> list[str]:
        """Names of metrics that this metric depends on."""
        return self._input_metric_names

    @staticmethod
    def _lookup_metric_values_for_arm(
        arm_df: pd.DataFrame,
        metric_name: str,
    ) -> pd.DataFrame:
        """Look up rows for *metric_name* by ``metric_name`` or
        ``metric_signature`` column.

        Args:
            arm_df: Data rows for a single arm, expected to contain
                ``metric_name`` and ``metric_signature`` columns.
            metric_name: Name (or signature) of the metric to select.

        Returns:
            The subset of ``arm_df`` whose ``metric_name`` or
            ``metric_signature`` equals ``metric_name``.
        """
        return arm_df[
            (arm_df["metric_name"] == metric_name)
            | (arm_df["metric_signature"] == metric_name)
        ]

    @property
    def summary_dict(self) -> dict[str, Any]:
        """Fields of this metric's configuration that will appear
        in the ``Summary`` analysis table.
        """
        return {
            **super().summary_dict,
            "input_metric_names": self._input_metric_names,
        }

ax/core/experiment.py

Lines changed: 81 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from ax.core.base_trial import BaseTrial
3131
from ax.core.batch_trial import BatchTrial
3232
from ax.core.data import combine_data_rows_favoring_recent, Data
33+
from ax.core.derived_metric import DerivedMetric
3334
from ax.core.experiment_status import ExperimentStatus
3435
from ax.core.generator_run import GeneratorRun
3536
from ax.core.llm_provider import LLMMessage
@@ -1045,19 +1046,94 @@ def _lookup_or_fetch_trials_results(
10451046
logger.debug("No trials are in a state expecting data.")
10461047
return {}
10471048
metrics_to_fetch = list(metrics or self.metrics.values())
1048-
metrics_by_class = self._metrics_by_class(metrics=metrics_to_fetch)
1049+
1050+
# Separate base metrics from derived metrics.
1051+
# Derived metrics must be fetched after base metrics because they
1052+
# depend on base metric data being available in the cache.
1053+
base_metrics: list[Metric] = [
1054+
m for m in metrics_to_fetch if not isinstance(m, DerivedMetric)
1055+
]
1056+
derived_metrics: list[Metric] = [
1057+
m for m in metrics_to_fetch if isinstance(m, DerivedMetric)
1058+
]
1059+
1060+
results: dict[int, dict[str, MetricFetchResult]] = {}
1061+
contains_new_data = False
1062+
1063+
# Phase 1: Fetch all base (non-derived) metrics first.
1064+
if base_metrics:
1065+
base_results, base_new = self._fetch_metrics_by_class(
1066+
trials=trials,
1067+
metrics=base_metrics,
1068+
**kwargs,
1069+
)
1070+
results = base_results
1071+
contains_new_data = base_new
1072+
1073+
# Attach base metric results to the cache BEFORE fetching derived
1074+
# metrics so they can access base data via lookup_data().
1075+
if base_new and derived_metrics:
1076+
self._try_attach_fetch_results(base_results)
1077+
1078+
# Phase 2: Fetch derived metrics (they look up base data from cache).
1079+
if derived_metrics:
1080+
derived_results, derived_new = self._fetch_metrics_by_class(
1081+
trials=trials,
1082+
metrics=derived_metrics,
1083+
**kwargs,
1084+
)
1085+
for trial_index, trial_metrics in derived_results.items():
1086+
results.setdefault(trial_index, {}).update(trial_metrics)
1087+
contains_new_data = contains_new_data or derived_new
1088+
1089+
# Attach all results (base + derived).
1090+
if contains_new_data:
1091+
self._try_attach_fetch_results(results)
1092+
1093+
return results
1094+
1095+
def _try_attach_fetch_results(
1096+
self,
1097+
results: dict[int, dict[str, MetricFetchResult]],
1098+
) -> None:
1099+
"""Attach fetch results to the experiment cache, logging on error."""
1100+
try:
1101+
self.attach_fetch_results(results=results)
1102+
except ValueError as e:
1103+
logger.error(
1104+
f"Encountered ValueError {e} while attaching results. "
1105+
"Proceeding and returning results fetched without attaching."
1106+
)
1107+
1108+
def _fetch_metrics_by_class(
1109+
self,
1110+
trials: list[BaseTrial],
1111+
metrics: list[Metric],
1112+
**kwargs: Any,
1113+
) -> tuple[dict[int, dict[str, MetricFetchResult]], bool]:
1114+
"""Fetch metrics grouped by class.
1115+
1116+
Args:
1117+
trials: List of trials to fetch data for.
1118+
metrics: List of metrics to fetch.
1119+
**kwargs: Additional keyword arguments passed to fetch methods.
1120+
1121+
Returns:
1122+
A tuple of (results dict, contains_new_data bool).
1123+
"""
1124+
metrics_by_class = self._metrics_by_class(metrics=metrics)
10491125

10501126
results: dict[int, dict[str, MetricFetchResult]] = {}
10511127
contains_new_data = False
10521128

1053-
for metric_cls, metrics in metrics_by_class.items():
1054-
first_metric_of_group = metrics[0]
1129+
for _metric_cls, cls_metrics in metrics_by_class.items():
1130+
first_metric_of_group = cls_metrics[0]
10551131
(
10561132
new_fetch_results,
10571133
new_results_contains_new_data,
10581134
) = first_metric_of_group.fetch_data_prefer_lookup(
10591135
experiment=self,
1060-
metrics=metrics_by_class[metric_cls],
1136+
metrics=cls_metrics,
10611137
trials=trials,
10621138
**kwargs,
10631139
)
@@ -1077,16 +1153,7 @@ def _lookup_or_fetch_trials_results(
10771153
for trial in trials
10781154
}
10791155

1080-
if contains_new_data:
1081-
try:
1082-
self.attach_fetch_results(results=results)
1083-
except ValueError as e:
1084-
logger.error(
1085-
f"Encountered ValueError {e} while attaching results. Proceeding "
1086-
"and returning Results fetched without attaching."
1087-
)
1088-
1089-
return results
1156+
return results, contains_new_data
10901157

10911158
@copy_doc(BaseTrial.fetch_data)
10921159
def _fetch_trial_data(

0 commit comments

Comments
 (0)