 
 from logging import Logger
 from numbers import Number
-from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import numpy as np
 from ax.core.observation import Observation, ObservationData
 from ax.core.optimization_config import OptimizationConfig
 from ax.modelbridge.base import ModelBridge
 from ax.utils.common.logger import get_logger
-from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
+
+from ax.utils.stats.model_fit_stats import (
+    _correlation_coefficient,
+    _fisher_exact_test_p,
+    _log_likelihood,
+    _mape,
+    _mean_prediction_ci,
+    _rank_correlation,
+    _total_raw_effect,
+    compute_model_fit_metrics,
+    ModelFitMetricProtocol,
+)
 
 logger: Logger = get_logger(__name__)
 
@@ -225,27 +248,36 @@ def compute_diagnostics(result: List[CVResult]) -> CVDiagnostics:
             k = res.predicted.metric_names.index(metric_name)
             y_pred[metric_name].append(res.predicted.means[k])
             se_pred[metric_name].append(np.sqrt(res.predicted.covariance[k, k]))
+    y_obs = _arrayify_dict_values(y_obs)
+    y_pred = _arrayify_dict_values(y_pred)
+    se_pred = _arrayify_dict_values(se_pred)
+
+    # We need to cast here since pyre infers specific types T < ModelFitMetricProtocol
+    # for the dict values, which is type variant upon initialization, leading
+    # diagnostic_fns to not be recognized as a Mapping[str, ModelFitMetricProtocol],
+    # see the last tip in the Pyre docs on [9] Incompatible Variable Type:
+    # https://staticdocs.internalfb.com/pyre/docs/errors/#9-incompatible-variable-type
+    diagnostic_fns = cast(
+        Mapping[str, ModelFitMetricProtocol],
+        {
+            MEAN_PREDICTION_CI: _mean_prediction_ci,
+            MAPE: _mape,
+            TOTAL_RAW_EFFECT: _total_raw_effect,
+            CORRELATION_COEFFICIENT: _correlation_coefficient,
+            RANK_CORRELATION: _rank_correlation,
+            FISHER_EXACT_TEST_P: _fisher_exact_test_p,
+            LOG_LIKELIHOOD: _log_likelihood,
+        },
+    )
+    diagnostics = compute_model_fit_metrics(
+        y_obs=y_obs, y_pred=y_pred, se_pred=se_pred, fit_metrics_dict=diagnostic_fns
+    )
+    return diagnostics
 
-    diagnostic_fns = {
-        MEAN_PREDICTION_CI: _mean_prediction_ci,
-        MAPE: _mape,
-        TOTAL_RAW_EFFECT: _total_raw_effect,
-        CORRELATION_COEFFICIENT: _correlation_coefficient,
-        RANK_CORRELATION: _rank_correlation,
-        FISHER_EXACT_TEST_P: _fisher_exact_test_p,
-        LOG_LIKELIHOOD: _log_likelihood,
-    }
 
-    diagnostics: Dict[str, Dict[str, float]] = defaultdict(dict)
-    # Get all per-metric diagnostics.
-    for metric_name in y_obs:
-        for name, fn in diagnostic_fns.items():
-            diagnostics[name][metric_name] = fn(
-                y_obs=np.array(y_obs[metric_name]),
-                y_pred=np.array(y_pred[metric_name]),
-                se_pred=np.array(se_pred[metric_name]),
-            )
-    return diagnostics
+def _arrayify_dict_values(d: Dict[str, List[float]]) -> Dict[str, np.ndarray]:
+    """Helper to convert dictionary values to numpy arrays."""
+    return {k: np.array(v) for k, v in d.items()}
 
 
 def assess_model_fit(
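A note on the new structure in the hunk above: compute_diagnostics no longer loops over metrics itself; it hands the per-metric arrays and a name-to-function mapping to compute_model_fit_metrics. Judging from this call site and the helpers removed in the next hunk, a ModelFitMetricProtocol-compatible function appears to take y_obs, y_pred, and se_pred arrays and return a float. The sketch below is not Ax source but an illustration under that assumption; the _rmse metric and the toy data are made up.

import numpy as np

from ax.utils.stats.model_fit_stats import compute_model_fit_metrics


def _rmse(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
    # se_pred is unused here, but the argument is kept so the signature matches
    # the other fit metrics (e.g. _mape, _log_likelihood).
    return float(np.sqrt(np.mean((y_pred - y_obs) ** 2)))


# One array per metric name, mirroring what compute_diagnostics builds after
# _arrayify_dict_values.
y_obs = {"metric_a": np.array([1.0, 2.0, 3.0])}
y_pred = {"metric_a": np.array([1.1, 1.9, 3.2])}
se_pred = {"metric_a": np.array([0.1, 0.1, 0.1])}

diagnostics = compute_model_fit_metrics(
    y_obs=y_obs, y_pred=y_pred, se_pred=se_pred, fit_metrics_dict={"RMSE": _rmse}
)
# Expected to mirror the nesting the old loop produced:
# {"RMSE": {"metric_a": <float>}}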
@@ -339,63 +371,6 @@ def _gen_train_test_split(
         yield set(arm_names[:-n_test]), set(arm_names[-n_test:])
 
 
-def _mean_prediction_ci(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    # Pyre does not allow float * np.ndarray.
-    return float(np.mean(1.96 * 2 * se_pred / np.abs(y_obs)))
-
-
-def _log_likelihood(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    return float(np.sum(norm.logpdf(y_obs, loc=y_pred, scale=se_pred)))
-
-
-def _mape(y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray) -> float:
-    return float(np.mean(np.abs((y_pred - y_obs) / y_obs)))
-
-
-def _total_raw_effect(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    min_y_obs = np.min(y_obs)
-    return float((np.max(y_obs) - min_y_obs) / min_y_obs)
-
-
-def _correlation_coefficient(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    with np.errstate(invalid="ignore"):
-        rho, _ = pearsonr(y_pred, y_obs)
-    return float(rho)
-
-
-def _rank_correlation(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    with np.errstate(invalid="ignore"):
-        rho, _ = spearmanr(y_pred, y_obs)
-    return float(rho)
-
-
-def _fisher_exact_test_p(
-    y_obs: np.ndarray, y_pred: np.ndarray, se_pred: np.ndarray
-) -> float:
-    n_half = len(y_obs) // 2
-    top_obs = y_obs.argsort(axis=0)[-n_half:]
-    top_est = y_pred.argsort(axis=0)[-n_half:]
-    # Construct contingency table
-    tp = len(set(top_est).intersection(top_obs))
-    fp = n_half - tp
-    fn = n_half - tp
-    tn = (len(y_obs) - n_half) - (n_half - tp)
-    table = np.array([[tp, fp], [fn, tn]])
-    # Compute the test statistic
-    _, p = fisher_exact(table, alternative="greater")
-    return float(p)
-
-
 class BestModelSelector(ABC):
     @abstractmethod
     def best_diagnostic(self, diagnostics: List[CVDiagnostics]) -> int:
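For reference, the seven helpers deleted in this hunk carry the same names that are now imported from ax.utils.stats.model_fit_stats at the top of the file, so they were relocated rather than dropped. The least self-explanatory of them is the Fisher exact test; the short sketch below (toy numbers, logic copied from the removed lines above) shows the contingency table it builds from the top halves of the observed and predicted rankings.

import numpy as np
from scipy.stats import fisher_exact

y_obs = np.array([0.1, 0.4, 0.2, 0.9])   # observed effects for four arms
y_pred = np.array([0.2, 0.3, 0.1, 0.8])  # model predictions for the same arms

n_half = len(y_obs) // 2                      # 2
top_obs = y_obs.argsort(axis=0)[-n_half:]     # top-half observed arms: indices [1, 3]
top_est = y_pred.argsort(axis=0)[-n_half:]    # top-half predicted arms: indices [1, 3]

tp = len(set(top_est).intersection(top_obs))  # 2: both top arms are ranked correctly
fp = n_half - tp                              # 0
fn = n_half - tp                              # 0
tn = (len(y_obs) - n_half) - (n_half - tp)    # 2
table = np.array([[tp, fp], [fn, tn]])
_, p = fisher_exact(table, alternative="greater")  # one-sided p-value for positive association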