1313from ax .adapter .base import Adapter
1414from ax .adapter .cross_validation import cross_validate , CVResult
1515from ax .analysis .analysis import Analysis
16+ from ax .analysis .healthcheck .predictable_metrics import DEFAULT_MODEL_FIT_THRESHOLD
1617from ax .analysis .plotly .color_constants import AX_BLUE
1718from ax .analysis .plotly .plotly_analysis import create_plotly_analysis_card
1819from ax .analysis .plotly .utils import get_scatter_point_color , Z_SCORE_95_CI
@@ -106,6 +107,7 @@ def __init__(
106107 self .untransform = untransform
107108 self .trial_index = trial_index
108109 self .labels : dict [str , str ] = {** labels } if labels is not None else {}
110+ self ._r2s : dict [str , float ] = {}
109111
110112 @override
111113 def validate_applicable_state (
@@ -144,6 +146,7 @@ def compute(
144146 relevant_adapter ._experiment .signature_to_metric [signature ].name
145147 for signature in relevant_adapter ._metric_signatures
146148 ]
149+ self ._r2s = {}
147150 for metric_name in self .metric_names or relevant_adapter_metric_names :
148151 df = _prepare_data (
149152 metric_name = metric_name , cv_results = cv_results , adapter = relevant_adapter
@@ -162,6 +165,7 @@ def compute(
162165 y_obs = df ["observed" ].to_numpy (),
163166 y_pred = df ["predicted" ].to_numpy (),
164167 )
168+ self ._r2s [metric_title ] = r_squared
165169
166170 # Define the cross-validation description based on the number of folds
167171 cv_description = (
@@ -202,6 +206,50 @@ def compute(
202206
203207 cards .append (card )
204208
209+ # Create a summary table of R2 values for all metrics
210+ if self ._r2s :
211+ threshold = DEFAULT_MODEL_FIT_THRESHOLD
212+ metric_names_list = list (self ._r2s .keys ())
213+ r2_values = [f"{ v :.2f} " for v in self ._r2s .values ()]
214+ fill_colors = [
215+ "rgba(0, 200, 0, 0.15)" if r2 >= threshold else "white"
216+ for r2 in self ._r2s .values ()
217+ ]
218+ r2_fig = go .Figure (
219+ data = [
220+ go .Table (
221+ columnwidth = [4 , 1 ],
222+ header = {
223+ "values" : ["Metric" , "R\u00b2 " ],
224+ "align" : "left" ,
225+ },
226+ cells = {
227+ "values" : [metric_names_list , r2_values ],
228+ "align" : "left" ,
229+ "fill_color" : [fill_colors , fill_colors ],
230+ },
231+ )
232+ ]
233+ )
234+ r2_card = create_plotly_analysis_card (
235+ name = self .__class__ .__name__ ,
236+ title = "Summary of model fits" ,
237+ subtitle = (
238+ "R\u00b2 (coefficient of determination) measures how well"
239+ " the model predicts each metric. Higher values indicate"
240+ " better model fit. Metrics with R\u00b2 >="
241+ f" { threshold } are highlighted in green."
242+ ),
243+ df = pd .DataFrame (
244+ {
245+ "Metric" : metric_names_list ,
246+ "R\u00b2 " : list (self ._r2s .values ()),
247+ }
248+ ),
249+ fig = r2_fig ,
250+ )
251+ cards .append (r2_card )
252+
205253 return self ._create_analysis_card_group (
206254 title = CV_CARDGROUP_TITLE ,
207255 subtitle = CV_CARDGROUP_SUBTITLE ,
0 commit comments