From 6a13a7516af25072075c3d6a1f80ceef64ac8481 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Thu, 16 Jan 2025 10:49:35 -0800 Subject: [PATCH] Don't assume order is the same in observed and predicted features Summary: This is causing only 3 of 14 arms to show on an experiment. We're still assuming that within an observation the order of the metric names matches the order of the corresponding data. Differential Revision: D68274239 --- ax/analysis/plotly/cross_validation.py | 48 +++++++++++--------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index 4214a5156e0..57e609a03a9 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -151,34 +151,26 @@ def _prepare_data( records = [] for observed, predicted in cv_results: - if trial_index is not None: - if ( - observed.features.trial_index is not None - and observed.features.trial_index >= trial_index - ): - raise UserInputError( - "CrossValidationPlot was specified to be for the generation of " - f"trial {trial_index}, but has observations from trial " - f"{observed.features.trial_index}." - ) - for i in range(len(observed.data.metric_names)): - # Find the index of the metric we want to plot - if not ( - observed.data.metric_names[i] == metric_name - and predicted.metric_names[i] == metric_name - ): - continue - - record = { - "arm_name": observed.arm_name, - "observed": observed.data.means[i], - "predicted": predicted.means[i], - # Take the square root of the the SEM to get the standard deviation - "observed_sem": observed.data.covariance[i][i] ** 0.5, - "predicted_sem": predicted.covariance[i][i] ** 0.5, - } - records.append(record) - break + observed_i = next( + i + for i in range(len(observed.data.metric_names)) + if observed.data.metric_names[i] == metric_name + ) + predicted_i = next( + i + for i in range(len(observed.data.metric_names)) + if predicted.metric_names[i] == metric_name + ) + + record = { + "arm_name": observed.arm_name, + "observed": observed.data.means[observed_i], + "predicted": predicted.means[predicted_i], + # Take the square root of the variance to get the standard deviation + "observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5, + "predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5, + } + records.append(record) return pd.DataFrame.from_records(records)