Skip to content

Commit 6a13a75

Browse files
Daniel Cohenfacebook-github-bot
authored andcommitted
Don't assume order is the same in observed and predicted features
Summary: This is causing only 3 of 14 arms to show on an experiment. We're still assuming that within an observation the order of the metric names matches the order of the corresponding data. Differential Revision: D68274239
1 parent ca93faa commit 6a13a75

File tree

1 file changed

+20
-28
lines changed

1 file changed

+20
-28
lines changed

ax/analysis/plotly/cross_validation.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -151,34 +151,26 @@ def _prepare_data(
151151

152152
records = []
153153
for observed, predicted in cv_results:
154-
if trial_index is not None:
155-
if (
156-
observed.features.trial_index is not None
157-
and observed.features.trial_index >= trial_index
158-
):
159-
raise UserInputError(
160-
"CrossValidationPlot was specified to be for the generation of "
161-
f"trial {trial_index}, but has observations from trial "
162-
f"{observed.features.trial_index}."
163-
)
164-
for i in range(len(observed.data.metric_names)):
165-
# Find the index of the metric we want to plot
166-
if not (
167-
observed.data.metric_names[i] == metric_name
168-
and predicted.metric_names[i] == metric_name
169-
):
170-
continue
171-
172-
record = {
173-
"arm_name": observed.arm_name,
174-
"observed": observed.data.means[i],
175-
"predicted": predicted.means[i],
176-
# Take the square root of the the SEM to get the standard deviation
177-
"observed_sem": observed.data.covariance[i][i] ** 0.5,
178-
"predicted_sem": predicted.covariance[i][i] ** 0.5,
179-
}
180-
records.append(record)
181-
break
154+
observed_i = next(
155+
i
156+
for i in range(len(observed.data.metric_names))
157+
if observed.data.metric_names[i] == metric_name
158+
)
159+
predicted_i = next(
160+
i
161+
for i in range(len(observed.data.metric_names))
162+
if predicted.metric_names[i] == metric_name
163+
)
164+
165+
record = {
166+
"arm_name": observed.arm_name,
167+
"observed": observed.data.means[observed_i],
168+
"predicted": predicted.means[predicted_i],
169+
# Take the square root of the variance to get the standard deviation
170+
"observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5,
171+
"predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
172+
}
173+
records.append(record)
182174

183175
return pd.DataFrame.from_records(records)
184176

0 commit comments

Comments
 (0)