Commit 7a1c2f0

eonofrey and facebook-github-bot authored and committed
Don't assume order and length is the same between observed and predicted features
Summary: Combining diffs D68274239 & D68294872:

1 - **D68274239 - [analysis] Don't assume order is the same in observed and predicted features**

This is causing only 3 of 14 arms to show on an experiment. We're still assuming that within an observation the order of the metric names matches the order of the corresponding data.

2 - **D68294872 - [analysis] Don't assume observed and predicted metrics are the same length**

Comment from [here](https://www.internalfb.com/diff/D68274239?dst_version_fbid=9435339503152040&transaction_fbid=1415387593232440):

In N6432597 I encounter a `StopIteration` error when rebased on D68274239. I believe this is because `predicted.metric_names` is quite a bit longer than `observed.data.metric_names` and so there's a chance that

```
predicted_i = next(
    i
    for i in range(len(observed.data.metric_names))
    if predicted.metric_names[i] == metric_name
)
```

never finds the metric it needs and throws the `StopIteration` error.

Differential Revision: D68336952
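The failure mode described in the comment can be reproduced in isolation. The following is a minimal sketch, not code from the Ax repository: the metric names and the helpers `find_bounded` and `find_index` are hypothetical. It contrasts a lookup bounded by the length of the shorter list with a lookup over the full list that falls back to `None`.

```python
# Hypothetical metric name lists; predicted is longer and ordered differently.
observed_names = ["a", "b"]
predicted_names = ["c", "d", "a", "b"]


def find_bounded(observed, predicted, metric_name):
    # Mirrors the quoted snippet: only the first len(observed) positions of
    # predicted are scanned, and next() has no default value.
    return next(i for i in range(len(observed)) if predicted[i] == metric_name)


def find_index(names, metric_name):
    # Scans the whole list and returns None instead of raising when absent.
    return next((i for i, name in enumerate(names) if name == metric_name), None)


try:
    find_bounded(observed_names, predicted_names, "a")
except StopIteration:
    print("bounded lookup raised StopIteration")

print(find_index(predicted_names, "a"))
print(find_index(observed_names, "missing"))
```

Passing a default as the second argument to `next()` is what turns a missing metric into a value the caller can check rather than an exception.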
1 parent 81fba5c commit 7a1c2f0

1 file changed: +20 -15 lines changed

ax/analysis/plotly/cross_validation.py

Lines changed: 20 additions & 15 deletions
@@ -161,25 +161,30 @@ def _prepare_data(
                     f"trial {trial_index}, but has observations from trial "
                     f"{observed.features.trial_index}."
                 )
-        for i in range(len(observed.data.metric_names)):
-            # Find the index of the metric we want to plot
-            if not (
-                observed.data.metric_names[i] == metric_name
-                and predicted.metric_names[i] == metric_name
-            ):
-                continue
-
+        # Find the index of the metric in observed and predicted
+        observed_i = next(
+            (
+                i
+                for i, name in enumerate(observed.data.metric_names)
+                if name == metric_name
+            ),
+            None,
+        )
+        predicted_i = next(
+            (i for i, name in enumerate(predicted.metric_names) if name == metric_name),
+            None,
+        )
+        # Check if both indices are found
+        if observed_i is not None and predicted_i is not None:
             record = {
                 "arm_name": observed.arm_name,
-                "observed": observed.data.means[i],
-                "predicted": predicted.means[i],
-                # Take the square root of the the SEM to get the standard deviation
-                "observed_sem": observed.data.covariance[i][i] ** 0.5,
-                "predicted_sem": predicted.covariance[i][i] ** 0.5,
+                "observed": observed.data.means[observed_i],
+                "predicted": predicted.means[predicted_i],
+                # Take the square root of the SEM to get the standard deviation
+                "observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5,
+                "predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
             }
             records.append(record)
-            break
-
     return pd.DataFrame.from_records(records)
 
 
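To see the changed lookup work end to end, here is a small self-contained sketch of the same pattern outside of Ax. `MetricBlock`, `index_of`, and `build_record` are hypothetical stand-ins introduced for illustration; they are not the library's own types or functions, and plain nested lists stand in for the covariance matrices.

```python
from dataclasses import dataclass


@dataclass
class MetricBlock:
    # Hypothetical stand-in for the observed/predicted containers in the diff.
    metric_names: list
    means: list
    covariance: list  # square matrix as nested lists


def index_of(names, metric_name):
    # Position of metric_name in names, or None when it is absent.
    return next((i for i, name in enumerate(names) if name == metric_name), None)


def build_record(arm_name, observed, predicted, metric_name):
    # Each side gets its own index, so neither order nor length has to match.
    observed_i = index_of(observed.metric_names, metric_name)
    predicted_i = index_of(predicted.metric_names, metric_name)
    if observed_i is None or predicted_i is None:
        return None
    return {
        "arm_name": arm_name,
        "observed": observed.means[observed_i],
        "predicted": predicted.means[predicted_i],
        # Square root of the diagonal covariance entry for that metric
        "observed_sem": observed.covariance[observed_i][observed_i] ** 0.5,
        "predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
    }


observed = MetricBlock(["m2", "m1"], [2.0, 1.0], [[4.0, 0.0], [0.0, 1.0]])
predicted = MetricBlock(
    ["m0", "m1", "m2"],
    [0.0, 1.1, 2.2],
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 9.0]],
)
record = build_record("arm_0", observed, predicted, "m2")
```

In this sketch a metric missing from either side yields `None` rather than an exception, which corresponds to the diff skipping the record when either index is not found.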
