Skip to content

Commit

Permalink
Don't assume order is the same in observed and predicted features
Browse files Browse the repository at this point in the history
Summary: Assuming that metrics appear in the same order in the observed and predicted data was causing only 3 of 14 arms to show on an experiment. We're still assuming that, within an observation, the order of the metric names matches the order of the corresponding data.

Differential Revision: D68274239
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jan 16, 2025
1 parent ca93faa commit 6a13a75
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,34 +151,26 @@ def _prepare_data(

records = []
for observed, predicted in cv_results:
if trial_index is not None:
if (
observed.features.trial_index is not None
and observed.features.trial_index >= trial_index
):
raise UserInputError(
"CrossValidationPlot was specified to be for the generation of "
f"trial {trial_index}, but has observations from trial "
f"{observed.features.trial_index}."
)
for i in range(len(observed.data.metric_names)):
# Find the index of the metric we want to plot
if not (
observed.data.metric_names[i] == metric_name
and predicted.metric_names[i] == metric_name
):
continue

record = {
"arm_name": observed.arm_name,
"observed": observed.data.means[i],
"predicted": predicted.means[i],
# Take the square root of the SEM to get the standard deviation
"observed_sem": observed.data.covariance[i][i] ** 0.5,
"predicted_sem": predicted.covariance[i][i] ** 0.5,
}
records.append(record)
break
observed_i = next(
i
for i in range(len(observed.data.metric_names))
if observed.data.metric_names[i] == metric_name
)
predicted_i = next(
i
for i in range(len(observed.data.metric_names))
if predicted.metric_names[i] == metric_name
)

record = {
"arm_name": observed.arm_name,
"observed": observed.data.means[observed_i],
"predicted": predicted.means[predicted_i],
# Take the square root of the variance to get the standard deviation
"observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5,
"predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
}
records.append(record)

return pd.DataFrame.from_records(records)

Expand Down

0 comments on commit 6a13a75

Please sign in to comment.