
Commit e4aa110

SebastianAment authored and facebook-github-bot committed
Model fit metrics for logging (#1682)
Summary:
Pull Request resolved: #1682

This commit adds metrics to quantify and log model fit quality for each *experimental* metric. To start, it adds metrics based on the posterior statistics of the model; these can be extended readily by adding to the `fit_metrics` dict, and can be generalized to other metric types in follow-up work.

Reviewed By: Balandat

Differential Revision: D46816506

fbshipit-source-id: 0e4f9d9d8f4030b9793bdcf9ec5c218fccb91990
1 parent 94363dd commit e4aa110

12 files changed (+671 -97 lines)
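
The three default fit metrics this commit introduces are posterior statistics computed from observed values, the model's predictive means, and its predictive standard errors. A sketch of what they plausibly compute, assuming the conventional definitions; the actual implementations live in `ax/utils/stats/model_fit_stats.py`, which this view does not show:

```python
import numpy as np

# Sketch only: assumed conventional definitions of the default fit metrics;
# the real implementations are in ax/utils/stats/model_fit_stats.py.


def coefficient_of_determination(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # R^2 of the model's mean predictions against the observed values.
    ss_res = np.sum((y_obs - y_pred) ** 2)
    ss_tot = np.sum((y_obs - np.mean(y_obs)) ** 2)
    return float(1 - ss_res / ss_tot)


def mean_of_the_standardized_error(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # Mean of z = (y_obs - y_pred) / se_pred; near 0 for an unbiased model.
    return float(np.mean((y_obs - y_pred) / se_pred))


def std_of_the_standardized_error(
    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
) -> float:
    # Std of the standardized error; near 1 if uncertainties are calibrated.
    return float(np.std((y_obs - y_pred) / se_pred))
```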

ax/modelbridge/base.py (+107 -1)
@@ -12,8 +12,9 @@
 from dataclasses import dataclass, field

 from logging import Logger
-from typing import Any, Dict, List, MutableMapping, Optional, Set, Tuple, Type
+from typing import Any, cast, Dict, List, MutableMapping, Optional, Set, Tuple, Type

+import numpy as np
 from ax.core.arm import Arm
 from ax.core.data import Data
 from ax.core.experiment import Experiment
@@ -36,6 +37,13 @@
 from ax.models.types import TConfig
 from ax.utils.common.logger import get_logger
 from ax.utils.common.typeutils import checked_cast, not_none
+from ax.utils.stats.model_fit_stats import (
+    coefficient_of_determination,
+    compute_model_fit_metrics,
+    mean_of_the_standardized_error,
+    ModelFitMetricProtocol,
+    std_of_the_standardized_error,
+)
 from botorch.exceptions.warnings import InputDataWarning

 logger: Logger = get_logger(__name__)
@@ -918,6 +926,51 @@ def _cross_validate(
         """
         raise NotImplementedError  # pragma: no cover

+    def compute_model_fit_metrics(
+        self,
+        experiment: Experiment,
+        fit_metrics_dict: Optional[Dict[str, ModelFitMetricProtocol]] = None,
+    ) -> Dict[str, Dict[str, float]]:
+        """Computes the model fit metrics from the scheduler state.
+
+        Args:
+            experiment: The experiment with whose data to compute the model fit
+                metrics.
+            fit_metrics_dict: An optional dictionary with model fit metric
+                functions, i.e. a ModelFitMetricProtocol, as values and their
+                names as keys.
+
+        Returns:
+            A nested dictionary mapping from the *model fit* metric names and
+            the *experimental metric* names to the values of the model fit
+            metrics.
+
+            Example for an imaginary AutoML experiment that seeks to minimize
+            the test error after training an expensive model, with respect to
+            hyper-parameters:
+
+            ```
+            model_fit_dict = model_fit_metrics_from_scheduler(scheduler)
+            model_fit_dict["coefficient_of_determination"]["test error"] =
+                `coefficient of determination of the test error predictions`
+            ```
+        """
+        # TODO: cross_validate_by_trial-based generalization quality
+        # IDEA: store y_obs, y_pred, se_pred as well
+        y_obs, y_pred, se_pred = _predict_on_training_data(
+            model_bridge=self, experiment=experiment
+        )
+        if fit_metrics_dict is None:
+            fit_metrics_dict = {
+                "coefficient_of_determination": coefficient_of_determination,
+                "mean_of_the_standardized_error": mean_of_the_standardized_error,
+                "std_of_the_standardized_error": std_of_the_standardized_error,
+            }
+        fit_metrics_dict = cast(Dict[str, ModelFitMetricProtocol], fit_metrics_dict)
+
+        return compute_model_fit_metrics(
+            y_obs=y_obs,
+            y_pred=y_pred,
+            se_pred=se_pred,
+            fit_metrics_dict=fit_metrics_dict,
+        )
+
     def _set_kwargs_to_save(
         self,
         model_key: str,
@@ -1099,3 +1152,56 @@ def clamp_observation_features(
             )
             obsf.parameters[p.name] = p.upper
     return observation_features
+
+
+"""
+############################## Model Fit Metrics Utils ##############################
+"""
+
+
+def _predict_on_training_data(
+    model_bridge: ModelBridge,
+    experiment: Experiment,
+) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
+    """Makes predictions on the training data of a given experiment using a
+    ModelBridge and returns the observed values, and the corresponding
+    predictive means and predictive standard deviations of the model.
+
+    NOTE: This is a helper function for `ModelBridge.compute_model_fit_metrics`
+    and could be attached to the class.
+
+    Args:
+        model_bridge: A ModelBridge object with which to make predictions.
+        experiment: The experiment with whose data to compute the model fit
+            metrics.
+
+    Returns:
+        A tuple containing three dictionaries for 1) observed metric values, and
+        the model's associated 2) predictive means and 3) predictive standard
+        deviations.
+    """
+    data = experiment.fetch_data()
+    observations = observations_from_data(
+        experiment=experiment, data=data
+    )  # List[Observation]
+    observation_features = [obs.features for obs in observations]
+    mean_predicted, cov_predicted = model_bridge.predict(
+        observation_features=observation_features
+    )  # Dict[str, List[float]]
+    mean_observed = [
+        obs.data.means_dict for obs in observations
+    ]  # List[Dict[str, float]]
+    metric_names = list(data.metric_names)
+    mean_observed = _list_of_dicts_to_dict_of_lists(
+        list_of_dicts=mean_observed, keys=metric_names
+    )
+    # converting dictionary values to arrays
+    mean_observed = {k: np.array(v) for k, v in mean_observed.items()}
+    mean_predicted = {k: np.array(v) for k, v in mean_predicted.items()}
+    std_predicted = {m: np.sqrt(np.array(cov_predicted[m][m])) for m in cov_predicted}
+    return mean_observed, mean_predicted, std_predicted
+
+
+def _list_of_dicts_to_dict_of_lists(
+    list_of_dicts: List[Dict[str, float]], keys: List[str]
+) -> Dict[str, List[float]]:
+    """Converts a list of dicts indexed by a string to a dict of lists."""
+    return {key: [d[key] for d in list_of_dicts] for key in keys}
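
Taken together, this plumbing makes per-metric fit quality available in one call. A minimal usage sketch mirroring the test added at the bottom of this commit; `scheduler` and `experiment` are assumed to be a `Scheduler` that has already run some trials and its experiment:

```python
# Assumes `scheduler` has run trials and `experiment` is its experiment.
from ax.service.scheduler import get_fitted_model_bridge

model_bridge = get_fitted_model_bridge(scheduler)
fit_metrics = model_bridge.compute_model_fit_metrics(experiment)

# Nested dict: fit-metric name -> experimental-metric name -> value.
r2 = fit_metrics["coefficient_of_determination"]["branin"]
```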

ax/modelbridge/cross_validation.py (+53 -78)
@@ -11,14 +11,37 @@

 from logging import Logger
 from numbers import Number
-from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+)

 import numpy as np
 from ax.core.observation import Observation, ObservationData
 from ax.core.optimization_config import OptimizationConfig
 from ax.modelbridge.base import ModelBridge
 from ax.utils.common.logger import get_logger
-from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
+
+from ax.utils.stats.model_fit_stats import (
+    _correlation_coefficient,
+    _fisher_exact_test_p,
+    _log_likelihood,
+    _mape,
+    _mean_prediction_ci,
+    _rank_correlation,
+    _total_raw_effect,
+    compute_model_fit_metrics,
+    ModelFitMetricProtocol,
+)

 logger: Logger = get_logger(__name__)
@@ -225,27 +248,36 @@ def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics:
             k = res.predicted.metric_names.index(metric_name)
             y_pred[metric_name].append(res.predicted.means[k])
             se_pred[metric_name].append(np.sqrt(res.predicted.covariance[k, k]))
+    y_obs = _arrayify_dict_values(y_obs)
+    y_pred = _arrayify_dict_values(y_pred)
+    se_pred = _arrayify_dict_values(se_pred)

-    diagnostic_fns = {
-        MEAN_PREDICTION_CI: _mean_prediction_ci,
-        MAPE: _mape,
-        TOTAL_RAW_EFFECT: _total_raw_effect,
-        CORRELATION_COEFFICIENT: _correlation_coefficient,
-        RANK_CORRELATION: _rank_correlation,
-        FISHER_EXACT_TEST_P: _fisher_exact_test_p,
-        LOG_LIKELIHOOD: _log_likelihood,
-    }
+    # We need to cast here since pyre infers specific types T < ModelFitMetricProtocol
+    # for the dict values, which is type variant upon initialization, leading
+    # diagnostic_fns to not be recognized as a Mapping[str, ModelFitMetricProtocol],
+    # see the last tip in the Pyre docs on [9] Incompatible Variable Type:
+    # https://staticdocs.internalfb.com/pyre/docs/errors/#9-incompatible-variable-type
+    diagnostic_fns = cast(
+        Mapping[str, ModelFitMetricProtocol],
+        {
+            MEAN_PREDICTION_CI: _mean_prediction_ci,
+            MAPE: _mape,
+            TOTAL_RAW_EFFECT: _total_raw_effect,
+            CORRELATION_COEFFICIENT: _correlation_coefficient,
+            RANK_CORRELATION: _rank_correlation,
+            FISHER_EXACT_TEST_P: _fisher_exact_test_p,
+            LOG_LIKELIHOOD: _log_likelihood,
+        },
+    )
+    diagnostics = compute_model_fit_metrics(
+        y_obs=y_obs, y_pred=y_pred, se_pred=se_pred, fit_metrics_dict=diagnostic_fns
+    )
+    return diagnostics

-    diagnostics: Dict[str, Dict[str, float]] = defaultdict(dict)
-    # Get all per-metric diagnostics.
-    for metric_name in y_obs:
-        for name, fn in diagnostic_fns.items():
-            diagnostics[name][metric_name] = fn(
-                y_obs=np.array(y_obs[metric_name]),
-                y_pred=np.array(y_pred[metric_name]),
-                se_pred=np.array(se_pred[metric_name]),
-            )
-    return diagnostics
+
+def _arrayify_dict_values(d: Dict[str, List[float]]) -> Dict[str, np.ndarray]:
+    """Helper to convert dictionary values to numpy arrays."""
+    return {k: np.array(v) for k, v in d.items()}


 def assess_model_fit(
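
With this hunk, `compute_diagnostics` delegates the per-metric loop to the shared `compute_model_fit_metrics` from `ax.utils.stats.model_fit_stats`, which this view does not show. Presumably it implements the same double loop over fit-metric functions and experimental metrics that is removed here; a sketch under that assumption:

```python
from typing import Dict, Mapping

import numpy as np
from ax.utils.stats.model_fit_stats import ModelFitMetricProtocol


# Sketch of the shared helper, inferred from the removed inline loop; the
# actual implementation lives in ax/utils/stats/model_fit_stats.py.
def compute_model_fit_metrics(
    y_obs: Dict[str, np.ndarray],
    y_pred: Dict[str, np.ndarray],
    se_pred: Dict[str, np.ndarray],
    fit_metrics_dict: Mapping[str, ModelFitMetricProtocol],
) -> Dict[str, Dict[str, float]]:
    # One entry per fit metric, each mapping experimental-metric name to value,
    # just as the removed loop in compute_diagnostics did.
    return {
        name: {
            metric: fn(
                y_obs=y_obs[metric], y_pred=y_pred[metric], se_pred=se_pred[metric]
            )
            for metric in y_obs
        }
        for name, fn in fit_metrics_dict.items()
    }
```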
@@ -339,63 +371,6 @@ def _gen_train_test_split(
     yield set(arm_names[:-n_test]), set(arm_names[-n_test:])


-def _mean_prediction_ci(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    # Pyre does not allow float * np.ndarray.
-    return float(np.mean(1.96 * 2 * se_pred / np.abs(y_obs)))
-
-
-def _log_likelihood(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    return float(np.sum(norm.logpdf(y_obs, loc=y_pred, scale=se_pred)))
-
-
-def _mape(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
-    return float(np.mean(np.abs((y_pred - y_obs) / y_obs)))
-
-
-def _total_raw_effect(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    min_y_obs = np.min(y_obs)
-    return float((np.max(y_obs) - min_y_obs) / min_y_obs)
-
-
-def _correlation_coefficient(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    with np.errstate(invalid="ignore"):
-        rho, _ = pearsonr(y_pred, y_obs)
-    return float(rho)
-
-
-def _rank_correlation(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    with np.errstate(invalid="ignore"):
-        rho, _ = spearmanr(y_pred, y_obs)
-    return float(rho)
-
-
-def _fisher_exact_test_p(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    n_half = len(y_obs) // 2
-    top_obs = y_obs.argsort(axis=0)[-n_half:]
-    top_est = y_pred.argsort(axis=0)[-n_half:]
-    # Construct contingency table
-    tp = len(set(top_est).intersection(top_obs))
-    fp = n_half - tp
-    fn = n_half - tp
-    tn = (len(y_obs) - n_half) - (n_half - tp)
-    table = np.array([[tp, fp], [fn, tn]])
-    # Compute the test statistic
-    _, p = fisher_exact(table, alternative="greater")
-    return float(p)
-
-
 class BestModelSelector(ABC):
     @abstractmethod
     def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
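
The removed helpers above share one signature, which is what `ModelFitMetricProtocol` captures: a callable taking `y_obs`, `y_pred`, and `se_pred` arrays and returning a float. A hypothetical sketch of the protocol's shape, inferred from these call sites rather than taken from `ax/utils/stats/model_fit_stats.py`:

```python
from typing import Protocol

import numpy as np


class ModelFitMetricProtocol(Protocol):
    # Shape inferred from the call sites in this diff; not the actual definition.
    def __call__(
        self, y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
    ) -> float:
        ...
```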
New test file (+83 -0)

@@ -0,0 +1,83 @@
+#!/usr/bin/env fbpython
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, Dict
+
+from ax.core.experiment import Experiment
+from ax.core.objective import Objective
+from ax.core.optimization_config import OptimizationConfig
+from ax.metrics.branin import BraninMetric
+from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
+from ax.modelbridge.registry import Models
+from ax.runners.synthetic import SyntheticRunner
+from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions
+from ax.utils.common.constants import Keys
+from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_branin_search_space
+
+NUM_SOBOL = 5
+
+
+class TestModelBridgeFitMetrics(TestCase):
+    def setUp(self) -> None:
+        # setting up experiment and generation strategy
+        self.runner = SyntheticRunner()
+        self.branin_experiment = Experiment(
+            name="branin_test_experiment",
+            search_space=get_branin_search_space(),
+            runner=self.runner,
+            optimization_config=OptimizationConfig(
+                objective=Objective(
+                    metric=BraninMetric(name="branin", param_names=["x1", "x2"]),
+                    minimize=True,
+                ),
+            ),
+            is_test=True,
+        )
+        self.branin_experiment._properties[
+            Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF
+        ] = True
+        self.generation_strategy = GenerationStrategy(
+            steps=[
+                GenerationStep(
+                    model=Models.SOBOL, num_trials=NUM_SOBOL, max_parallelism=NUM_SOBOL
+                ),
+                GenerationStep(model=Models.GPEI, num_trials=-1),
+            ]
+        )
+
+    def test_model_fit_metrics(self) -> None:
+        scheduler = Scheduler(
+            experiment=self.branin_experiment,
+            generation_strategy=self.generation_strategy,
+            options=SchedulerOptions(),
+        )
+        # need to run some trials to initialize the ModelBridge
+        scheduler.run_n_trials(max_trials=NUM_SOBOL + 1)
+        model_bridge = get_fitted_model_bridge(scheduler)
+
+        # testing ModelBridge.compute_model_fit_metrics with default metrics
+        fit_metrics = model_bridge.compute_model_fit_metrics(self.branin_experiment)
+        r2 = fit_metrics.get("coefficient_of_determination")
+        self.assertIsInstance(r2, dict)
+        r2 = cast(Dict[str, float], r2)
+        self.assertTrue("branin" in r2)
+        r2_branin = r2["branin"]
+        self.assertIsInstance(r2_branin, float)
+
+        std = fit_metrics.get("std_of_the_standardized_error")
+        self.assertIsInstance(std, dict)
+        std = cast(Dict[str, float], std)
+        self.assertTrue("branin" in std)
+        std_branin = std["branin"]
+        self.assertIsInstance(std_branin, float)
+
+        # testing with empty metrics
+        empty_metrics = model_bridge.compute_model_fit_metrics(
+            self.branin_experiment, fit_metrics_dict={}
+        )
+        self.assertIsInstance(empty_metrics, dict)
+        self.assertTrue(len(empty_metrics) == 0)
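
Beyond the default and empty dictionaries exercised by this test, `fit_metrics_dict` accepts any callable matching the protocol. A sketch with an illustrative `rmse` metric, which is not part of this commit; `model_bridge` and `experiment` are assumed to come from a run Scheduler as in the test above:

```python
import numpy as np


def rmse(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
    # Root mean squared error of the mean predictions; se_pred is unused but
    # required by the ModelFitMetricProtocol signature.
    return float(np.sqrt(np.mean((y_pred - y_obs) ** 2)))


custom_metrics = model_bridge.compute_model_fit_metrics(
    experiment, fit_metrics_dict={"root_mean_squared_error": rmse}
)
```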
