@@ -345,157 +345,77 @@ def __init__(
345345 super ().__init__ (status , message_template , recommendation )
346346
347347
348+ @dataclasses .dataclass (frozen = True )
349+ class GoodnessOfFitMetrics :
350+ """The metrics for the Goodness of Fit Check."""
351+
352+ r_squared : float
353+ mape : float
354+ wmape : float
355+ r_squared_train : float | None = None
356+ mape_train : float | None = None
357+ wmape_train : float | None = None
358+ r_squared_test : float | None = None
359+ mape_test : float | None = None
360+ wmape_test : float | None = None
361+
362+
348363@dataclasses .dataclass (frozen = True )
349364class GoodnessOfFitCheckResult (CheckResult ):
350365 """The immutable result of the Goodness of Fit Check."""
351366
352367 case : GoodnessOfFitCases
353- metrics : Mapping [ str , float ]
368+ metrics : GoodnessOfFitMetrics
354369 is_holdout : bool = False
355370
356371 def __post_init__ (self ):
357372 if self .is_holdout :
358- required_keys = []
359- for suffix in [
360- constants .ALL_SUFFIX ,
361- constants .TRAIN_SUFFIX ,
362- constants .TEST_SUFFIX ,
363- ]:
364- required_keys .extend ([
365- f"{ constants .R_SQUARED } _{ suffix } " ,
366- f"{ constants .MAPE } _{ suffix } " ,
367- f"{ constants .WMAPE } _{ suffix } " ,
368- ])
369- if any (key not in self .metrics for key in required_keys ):
373+ if any (
374+ metric is None
375+ for metric in (
376+ self .metrics .r_squared_train ,
377+ self .metrics .mape_train ,
378+ self .metrics .wmape_train ,
379+ self .metrics .r_squared_test ,
380+ self .metrics .mape_test ,
381+ self .metrics .wmape_test ,
382+ )
383+ ):
370384 raise ValueError (
371385 "The message template is missing required formatting arguments for"
372- f" holdout case. Required keys: { required_keys } . Metrics:"
386+ " holdout case. Required keys: r_squared_train, mape_train,"
387+ " wmape_train, r_squared_test, mape_test, wmape_test. Metrics:"
373388 f" { self .metrics } ."
374389 )
375- elif any (
376- key not in self .metrics
377- for key in (
378- constants .R_SQUARED ,
379- constants .MAPE ,
380- constants .WMAPE ,
381- )
382- ):
383- raise ValueError (
384- "The message template is missing required formatting arguments:"
385- " r_squared, mape, wmape. Metrics:"
386- f" { self .metrics } ."
387- )
388-
389- @property
390- def r_squared (self ) -> float | None :
391- """The R-squared metric."""
392- return self .metrics [constants .R_SQUARED ] if not self .is_holdout else None
393-
394- @property
395- def mape (self ) -> float | None :
396- """The MAPE metric."""
397- return self .metrics [constants .MAPE ] if not self .is_holdout else None
398-
399- @property
400- def wmape (self ) -> float | None :
401- """The wMAPE metric."""
402- return self .metrics [constants .WMAPE ] if not self .is_holdout else None
403-
404- @property
405- def r_squared_all (self ) -> float | None :
406- """The R-squared metric for all data."""
407- return (
408- self .metrics [f"{ constants .R_SQUARED } _{ constants .ALL_SUFFIX } " ]
409- if self .is_holdout
410- else None
411- )
412-
413- @property
414- def mape_all (self ) -> float | None :
415- """The MAPE metric for all data."""
416- return (
417- self .metrics [f"{ constants .MAPE } _{ constants .ALL_SUFFIX } " ]
418- if self .is_holdout
419- else None
420- )
421-
422- @property
423- def wmape_all (self ) -> float | None :
424- """The wMAPE metric for all data."""
425- return (
426- self .metrics [f"{ constants .WMAPE } _{ constants .ALL_SUFFIX } " ]
427- if self .is_holdout
428- else None
429- )
430-
431- @property
432- def r_squared_train (self ) -> float | None :
433- """The R-squared metric for train data."""
434- return (
435- self .metrics [f"{ constants .R_SQUARED } _{ constants .TRAIN_SUFFIX } " ]
436- if self .is_holdout
437- else None
438- )
439-
440- @property
441- def mape_train (self ) -> float | None :
442- """The MAPE metric for train data."""
443- return (
444- self .metrics [f"{ constants .MAPE } _{ constants .TRAIN_SUFFIX } " ]
445- if self .is_holdout
446- else None
447- )
448-
449- @property
450- def wmape_train (self ) -> float | None :
451- """The wMAPE metric for train data."""
452- return (
453- self .metrics [f"{ constants .WMAPE } _{ constants .TRAIN_SUFFIX } " ]
454- if self .is_holdout
455- else None
456- )
457-
458- @property
459- def r_squared_test (self ) -> float | None :
460- """The R-squared metric for test data."""
461- return (
462- self .metrics [f"{ constants .R_SQUARED } _{ constants .TEST_SUFFIX } " ]
463- if self .is_holdout
464- else None
465- )
466-
467- @property
468- def mape_test (self ) -> float | None :
469- """The MAPE metric for test data."""
470- return (
471- self .metrics [f"{ constants .MAPE } _{ constants .TEST_SUFFIX } " ]
472- if self .is_holdout
473- else None
474- )
475-
476- @property
477- def wmape_test (self ) -> float | None :
478- """The wMAPE metric for test data."""
479- return (
480- self .metrics [f"{ constants .WMAPE } _{ constants .TEST_SUFFIX } " ]
481- if self .is_holdout
482- else None
483- )
484390
485391 @property
486392 def details (self ) -> Mapping [str , Any ]:
487393 """The check result details."""
488- return self .metrics
394+ return {
395+ f"{ constants .R_SQUARED } { constants .ALL_SUFFIX } " : self .metrics .r_squared ,
396+ f"{ constants .MAPE } { constants .ALL_SUFFIX } " : self .metrics .mape ,
397+ f"{ constants .WMAPE } { constants .ALL_SUFFIX } " : self .metrics .wmape ,
398+ f"{ constants .R_SQUARED } { constants .TRAIN_SUFFIX } " : (
399+ self .metrics .r_squared_train
400+ ),
401+ f"{ constants .MAPE } { constants .TRAIN_SUFFIX } " : self .metrics .mape_train ,
402+ f"{ constants .WMAPE } { constants .TRAIN_SUFFIX } " : self .metrics .wmape_train ,
403+ f"{ constants .R_SQUARED } { constants .TEST_SUFFIX } " : (
404+ self .metrics .r_squared_test
405+ ),
406+ f"{ constants .MAPE } { constants .TEST_SUFFIX } " : self .metrics .mape_test ,
407+ f"{ constants .WMAPE } { constants .TEST_SUFFIX } " : self .metrics .wmape_test ,
408+ }
489409
490410 @property
491411 def recommendation (self ) -> str :
492412 """The check result message."""
493413 if self .is_holdout :
494414 report_str = (
495- "R-squared = {r_squared_all :.4f} (All),"
415+ "R-squared = {r_squared :.4f} (All),"
496416 " {r_squared_train:.4f} (Train), {r_squared_test:.4f} (Test); MAPE"
497- " = {mape_all :.4f} (All), {mape_train:.4f} (Train),"
498- " {mape_test:.4f} (Test); wMAPE = {wmape_all :.4f} (All),"
417+ " = {mape :.4f} (All), {mape_train:.4f} (Train),"
418+ " {mape_test:.4f} (Test); wMAPE = {wmape :.4f} (All),"
499419 " {wmape_train:.4f} (Train), {wmape_test:.4f} (Test)" .format (
500420 ** self .details
501421 )
0 commit comments