Skip to content

Commit 9f16dcf

Browse files
lukmazThe Meridian Authors
authored andcommitted
Encapsulate GoF metrics into a dataclass
PiperOrigin-RevId: 864447565
1 parent 98cd8d9 commit 9f16dcf

File tree

4 files changed

+185
-216
lines changed

4 files changed

+185
-216
lines changed

meridian/analysis/review/checks.py

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def _set_metrics_from_gof_dataframe(
186186
metrics: MutableMapping[str, float],
187187
gof_df: pd.DataFrame,
188188
geo_granularity: str,
189-
suffix: str | None = None,
189+
suffix: str,
190190
) -> None:
191191
"""Sets the `metrics` variable of the GoodnessOfFitCheckResult.
192192
@@ -200,31 +200,23 @@ def _set_metrics_from_gof_dataframe(
200200
holdout set is not used) of filtered to a single evaluation set ("all",
201201
"train", or "test").
202202
geo_granularity: The geo granularity of the data ("geo" or "national").
203-
suffix: A suffix to add to the metric names (e.g., "all", "train", "test").
204-
If None, the metrics are added without a suffix.
203+
suffix: A suffix to add to the metric names (e.g., "_train", "_test").
205204
"""
206205
gof_metrics_pivoted = gof_df.pivot(
207206
index=constants.GEO_GRANULARITY,
208207
columns=constants.METRIC,
209208
values=constants.VALUE,
210209
)
211210
gof_metrics_series = gof_metrics_pivoted.loc[geo_granularity]
212-
if suffix is not None:
213-
metrics[f"{review_constants.R_SQUARED}_{suffix}"] = gof_metrics_series[
214-
constants.R_SQUARED
215-
]
216-
metrics[f"{review_constants.MAPE}_{suffix}"] = gof_metrics_series[
217-
constants.MAPE
218-
]
219-
metrics[f"{review_constants.WMAPE}_{suffix}"] = gof_metrics_series[
220-
constants.WMAPE
221-
]
222-
else:
223-
metrics[review_constants.R_SQUARED] = gof_metrics_series[
224-
constants.R_SQUARED
225-
]
226-
metrics[review_constants.MAPE] = gof_metrics_series[constants.MAPE]
227-
metrics[review_constants.WMAPE] = gof_metrics_series[constants.WMAPE]
211+
metrics[f"{review_constants.R_SQUARED}{suffix}"] = gof_metrics_series[
212+
constants.R_SQUARED
213+
]
214+
metrics[f"{review_constants.MAPE}{suffix}"] = gof_metrics_series[
215+
constants.MAPE
216+
]
217+
metrics[f"{review_constants.WMAPE}{suffix}"] = gof_metrics_series[
218+
constants.WMAPE
219+
]
228220

229221

230222
class GoodnessOfFitCheck(
@@ -243,7 +235,7 @@ def run(self) -> results.GoodnessOfFitCheckResult:
243235
gof_metrics = gof_df[gof_df[constants.GEO_GRANULARITY] == geo_granularity]
244236
is_holdout = constants.EVALUATION_SET_VAR in gof_df.columns
245237

246-
metrics = {}
238+
metrics_dict = {}
247239
case = results.GoodnessOfFitCases.PASS
248240

249241
if is_holdout:
@@ -256,28 +248,70 @@ def run(self) -> results.GoodnessOfFitCheckResult:
256248
gof_metrics[constants.EVALUATION_SET_VAR] == evaluation_set
257249
]
258250
_set_metrics_from_gof_dataframe(
259-
metrics=metrics,
251+
metrics=metrics_dict,
260252
gof_df=set_metrics,
261253
geo_granularity=geo_granularity,
262254
suffix=suffix,
263255
)
264-
if metrics[f"{review_constants.R_SQUARED}_{suffix}"] <= 0:
256+
if metrics_dict[f"{review_constants.R_SQUARED}{suffix}"] <= 0:
265257
case = results.GoodnessOfFitCases.REVIEW
258+
return results.GoodnessOfFitCheckResult(
259+
case=case,
260+
metrics=results.GoodnessOfFitMetrics(
261+
r_squared=metrics_dict[
262+
f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
263+
],
264+
mape=metrics_dict[
265+
f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
266+
],
267+
wmape=metrics_dict[
268+
f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
269+
],
270+
r_squared_train=metrics_dict[
271+
f"{review_constants.R_SQUARED}{review_constants.TRAIN_SUFFIX}"
272+
],
273+
mape_train=metrics_dict[
274+
f"{review_constants.MAPE}{review_constants.TRAIN_SUFFIX}"
275+
],
276+
wmape_train=metrics_dict[
277+
f"{review_constants.WMAPE}{review_constants.TRAIN_SUFFIX}"
278+
],
279+
r_squared_test=metrics_dict[
280+
f"{review_constants.R_SQUARED}{review_constants.TEST_SUFFIX}"
281+
],
282+
mape_test=metrics_dict[
283+
f"{review_constants.MAPE}{review_constants.TEST_SUFFIX}"
284+
],
285+
wmape_test=metrics_dict[
286+
f"{review_constants.WMAPE}{review_constants.TEST_SUFFIX}"
287+
],
288+
),
289+
is_holdout=is_holdout,
290+
)
266291
else:
267292
_set_metrics_from_gof_dataframe(
268-
metrics=metrics,
293+
metrics=metrics_dict,
269294
gof_df=gof_metrics,
270295
geo_granularity=geo_granularity,
271-
suffix=None,
296+
suffix=review_constants.ALL_SUFFIX,
272297
)
273-
if metrics[review_constants.R_SQUARED] <= 0:
298+
if metrics_dict[review_constants.R_SQUARED] <= 0:
274299
case = results.GoodnessOfFitCases.REVIEW
275-
276-
return results.GoodnessOfFitCheckResult(
277-
case=case,
278-
metrics=metrics,
279-
is_holdout=is_holdout,
280-
)
300+
return results.GoodnessOfFitCheckResult(
301+
case=case,
302+
metrics=results.GoodnessOfFitMetrics(
303+
r_squared=metrics_dict[
304+
f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
305+
],
306+
mape=metrics_dict[
307+
f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
308+
],
309+
wmape=metrics_dict[
310+
f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
311+
],
312+
),
313+
is_holdout=is_holdout,
314+
)
281315

282316

283317
# ==============================================================================

meridian/analysis/review/constants.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
R_SQUARED = "r_squared"
3333
MAPE = "mape"
3434
WMAPE = "wmape"
35-
ALL_SUFFIX = "all"
36-
TRAIN_SUFFIX = "train"
37-
TEST_SUFFIX = "test"
35+
ALL_SUFFIX = ""
36+
TRAIN_SUFFIX = "_train"
37+
TEST_SUFFIX = "_test"
3838
EVALUATION_SET_SUFFIXES = (ALL_SUFFIX, TRAIN_SUFFIX, TEST_SUFFIX)
3939
MEAN = "mean"
4040
VARIANCE = "variance"

meridian/analysis/review/results.py

Lines changed: 47 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -345,157 +345,77 @@ def __init__(
345345
super().__init__(status, message_template, recommendation)
346346

347347

348+
@dataclasses.dataclass(frozen=True)
349+
class GoodnessOfFitMetrics:
350+
"""The metrics for the Goodness of Fit Check."""
351+
352+
r_squared: float
353+
mape: float
354+
wmape: float
355+
r_squared_train: float | None = None
356+
mape_train: float | None = None
357+
wmape_train: float | None = None
358+
r_squared_test: float | None = None
359+
mape_test: float | None = None
360+
wmape_test: float | None = None
361+
362+
348363
@dataclasses.dataclass(frozen=True)
349364
class GoodnessOfFitCheckResult(CheckResult):
350365
"""The immutable result of the Goodness of Fit Check."""
351366

352367
case: GoodnessOfFitCases
353-
metrics: Mapping[str, float]
368+
metrics: GoodnessOfFitMetrics
354369
is_holdout: bool = False
355370

356371
def __post_init__(self):
357372
if self.is_holdout:
358-
required_keys = []
359-
for suffix in [
360-
constants.ALL_SUFFIX,
361-
constants.TRAIN_SUFFIX,
362-
constants.TEST_SUFFIX,
363-
]:
364-
required_keys.extend([
365-
f"{constants.R_SQUARED}_{suffix}",
366-
f"{constants.MAPE}_{suffix}",
367-
f"{constants.WMAPE}_{suffix}",
368-
])
369-
if any(key not in self.metrics for key in required_keys):
373+
if any(
374+
metric is None
375+
for metric in (
376+
self.metrics.r_squared_train,
377+
self.metrics.mape_train,
378+
self.metrics.wmape_train,
379+
self.metrics.r_squared_test,
380+
self.metrics.mape_test,
381+
self.metrics.wmape_test,
382+
)
383+
):
370384
raise ValueError(
371385
"The message template is missing required formatting arguments for"
372-
f" holdout case. Required keys: {required_keys}. Metrics:"
386+
" holdout case. Required keys: r_squared_train, mape_train,"
387+
" wmape_train, r_squared_test, mape_test, wmape_test. Metrics:"
373388
f" {self.metrics}."
374389
)
375-
elif any(
376-
key not in self.metrics
377-
for key in (
378-
constants.R_SQUARED,
379-
constants.MAPE,
380-
constants.WMAPE,
381-
)
382-
):
383-
raise ValueError(
384-
"The message template is missing required formatting arguments:"
385-
" r_squared, mape, wmape. Metrics:"
386-
f" {self.metrics}."
387-
)
388-
389-
@property
390-
def r_squared(self) -> float | None:
391-
"""The R-squared metric."""
392-
return self.metrics[constants.R_SQUARED] if not self.is_holdout else None
393-
394-
@property
395-
def mape(self) -> float | None:
396-
"""The MAPE metric."""
397-
return self.metrics[constants.MAPE] if not self.is_holdout else None
398-
399-
@property
400-
def wmape(self) -> float | None:
401-
"""The wMAPE metric."""
402-
return self.metrics[constants.WMAPE] if not self.is_holdout else None
403-
404-
@property
405-
def r_squared_all(self) -> float | None:
406-
"""The R-squared metric for all data."""
407-
return (
408-
self.metrics[f"{constants.R_SQUARED}_{constants.ALL_SUFFIX}"]
409-
if self.is_holdout
410-
else None
411-
)
412-
413-
@property
414-
def mape_all(self) -> float | None:
415-
"""The MAPE metric for all data."""
416-
return (
417-
self.metrics[f"{constants.MAPE}_{constants.ALL_SUFFIX}"]
418-
if self.is_holdout
419-
else None
420-
)
421-
422-
@property
423-
def wmape_all(self) -> float | None:
424-
"""The wMAPE metric for all data."""
425-
return (
426-
self.metrics[f"{constants.WMAPE}_{constants.ALL_SUFFIX}"]
427-
if self.is_holdout
428-
else None
429-
)
430-
431-
@property
432-
def r_squared_train(self) -> float | None:
433-
"""The R-squared metric for train data."""
434-
return (
435-
self.metrics[f"{constants.R_SQUARED}_{constants.TRAIN_SUFFIX}"]
436-
if self.is_holdout
437-
else None
438-
)
439-
440-
@property
441-
def mape_train(self) -> float | None:
442-
"""The MAPE metric for train data."""
443-
return (
444-
self.metrics[f"{constants.MAPE}_{constants.TRAIN_SUFFIX}"]
445-
if self.is_holdout
446-
else None
447-
)
448-
449-
@property
450-
def wmape_train(self) -> float | None:
451-
"""The wMAPE metric for train data."""
452-
return (
453-
self.metrics[f"{constants.WMAPE}_{constants.TRAIN_SUFFIX}"]
454-
if self.is_holdout
455-
else None
456-
)
457-
458-
@property
459-
def r_squared_test(self) -> float | None:
460-
"""The R-squared metric for test data."""
461-
return (
462-
self.metrics[f"{constants.R_SQUARED}_{constants.TEST_SUFFIX}"]
463-
if self.is_holdout
464-
else None
465-
)
466-
467-
@property
468-
def mape_test(self) -> float | None:
469-
"""The MAPE metric for test data."""
470-
return (
471-
self.metrics[f"{constants.MAPE}_{constants.TEST_SUFFIX}"]
472-
if self.is_holdout
473-
else None
474-
)
475-
476-
@property
477-
def wmape_test(self) -> float | None:
478-
"""The wMAPE metric for test data."""
479-
return (
480-
self.metrics[f"{constants.WMAPE}_{constants.TEST_SUFFIX}"]
481-
if self.is_holdout
482-
else None
483-
)
484390

485391
@property
486392
def details(self) -> Mapping[str, Any]:
487393
"""The check result details."""
488-
return self.metrics
394+
return {
395+
f"{constants.R_SQUARED}{constants.ALL_SUFFIX}": self.metrics.r_squared,
396+
f"{constants.MAPE}{constants.ALL_SUFFIX}": self.metrics.mape,
397+
f"{constants.WMAPE}{constants.ALL_SUFFIX}": self.metrics.wmape,
398+
f"{constants.R_SQUARED}{constants.TRAIN_SUFFIX}": (
399+
self.metrics.r_squared_train
400+
),
401+
f"{constants.MAPE}{constants.TRAIN_SUFFIX}": self.metrics.mape_train,
402+
f"{constants.WMAPE}{constants.TRAIN_SUFFIX}": self.metrics.wmape_train,
403+
f"{constants.R_SQUARED}{constants.TEST_SUFFIX}": (
404+
self.metrics.r_squared_test
405+
),
406+
f"{constants.MAPE}{constants.TEST_SUFFIX}": self.metrics.mape_test,
407+
f"{constants.WMAPE}{constants.TEST_SUFFIX}": self.metrics.wmape_test,
408+
}
489409

490410
@property
491411
def recommendation(self) -> str:
492412
"""The check result message."""
493413
if self.is_holdout:
494414
report_str = (
495-
"R-squared = {r_squared_all:.4f} (All),"
415+
"R-squared = {r_squared:.4f} (All),"
496416
" {r_squared_train:.4f} (Train), {r_squared_test:.4f} (Test); MAPE"
497-
" = {mape_all:.4f} (All), {mape_train:.4f} (Train),"
498-
" {mape_test:.4f} (Test); wMAPE = {wmape_all:.4f} (All),"
417+
" = {mape:.4f} (All), {mape_train:.4f} (Train),"
418+
" {mape_test:.4f} (Test); wMAPE = {wmape:.4f} (All),"
499419
" {wmape_train:.4f} (Train), {wmape_test:.4f} (Test)".format(
500420
**self.details
501421
)

0 commit comments

Comments
 (0)