Skip to content

Commit 1145c41

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
Fix Bug in Slice/Contour Plot (#4969)
Summary: Pull Request resolved: #4969 Fix a bug that casted BatchTrials to Trials for all experiments for slice/contour plots Reviewed By: mgarrard Differential Revision: D94983088 fbshipit-source-id: e8ea8811bde8a00fcc6bd6f926947ac0e377c183
1 parent e2a7c7a commit 1145c41

2 files changed

Lines changed: 67 additions & 4 deletions

File tree

ax/analysis/plotly/surface/tests/test_surface_utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
from unittest.mock import patch
99

10-
from ax.analysis.plotly.surface.utils import get_fixed_values_for_slice_or_contour
10+
from ax.analysis.plotly.surface.utils import (
11+
_get_best_trial_info,
12+
get_fixed_values_for_slice_or_contour,
13+
)
1114
from ax.core.arm import Arm
1215
from ax.service.ax_client import AxClient, ObjectiveProperties
1316
from ax.utils.common.testutils import TestCase
@@ -121,3 +124,50 @@ def test_returns_center_when_best_point_returns_none(self) -> None:
121124

122125
self.assertEqual(fixed_values, {"x": 0.0, "y": 0.0})
123126
self.assertEqual(description, "the center of the search space")
127+
128+
def test_get_best_trial_info_with_batch_trial(self) -> None:
129+
"""Test that _get_best_trial_info works with BatchTrial (PTS experiments),
130+
both when a matching arm is found and when no arm matches."""
131+
client = self._create_client()
132+
experiment = client.experiment
133+
trial = experiment.new_batch_trial()
134+
trial.add_arms_and_weights(
135+
arms=[
136+
Arm(parameters={"x": 0.1, "y": 0.2}, name="0_0"),
137+
Arm(parameters={"x": 0.3, "y": 0.4}, name="0_1"),
138+
Arm(parameters={"x": 0.5, "y": 0.6}, name="0_2"),
139+
],
140+
)
141+
142+
# Subtest 1: matching arm found
143+
with self.subTest("matching_arm_found"):
144+
with patch(
145+
"ax.analysis.plotly.surface.utils."
146+
"get_best_parameters_from_model_predictions_with_trial_index",
147+
return_value=(0, {"x": 0.3, "y": 0.4}, None),
148+
):
149+
result = _get_best_trial_info(
150+
experiment=experiment,
151+
generation_strategy=client.generation_strategy,
152+
)
153+
154+
self.assertIsNotNone(result)
155+
# pyre-ignore[16]: result is not None per assertion above
156+
parameterization, trial_index, arm_name = result
157+
self.assertEqual(parameterization, {"x": 0.3, "y": 0.4})
158+
self.assertEqual(trial_index, 0)
159+
self.assertEqual(arm_name, "0_1")
160+
161+
# Subtest 2: no matching arm returns None
162+
with self.subTest("no_matching_arm_returns_none"):
163+
with patch(
164+
"ax.analysis.plotly.surface.utils."
165+
"get_best_parameters_from_model_predictions_with_trial_index",
166+
return_value=(0, {"x": 0.9, "y": 0.9}, None),
167+
):
168+
result = _get_best_trial_info(
169+
experiment=experiment,
170+
generation_strategy=client.generation_strategy,
171+
)
172+
173+
self.assertIsNone(result)

ax/analysis/plotly/surface/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import TYPE_CHECKING
1111

1212
import numpy as np
13+
from ax.core.batch_trial import BatchTrial
1314
from ax.core.observation import ObservationFeatures
1415
from ax.core.parameter import (
1516
ChoiceParameter,
@@ -68,9 +69,21 @@ def _get_best_trial_info(
6869
return None
6970

7071
trial_index, parameterization, _prediction = result
71-
# Get the arm name from the trial
72-
trial = assert_is_instance(experiment.trials[trial_index], Trial)
73-
arm_name = none_throws(trial.arm).name
72+
# Get the arm name from the trial. Handle both single-arm Trial and
73+
# multi-arm BatchTrial (used by PTS experiments).
74+
base_trial = experiment.trials[trial_index]
75+
if isinstance(base_trial, Trial):
76+
arm_name = none_throws(base_trial.arm).name
77+
else:
78+
# BatchTrial: match the best parameterization to the corresponding arm.
79+
batch_trial = assert_is_instance(base_trial, BatchTrial)
80+
arm_name = None
81+
for arm in batch_trial.arms:
82+
if arm.parameters == parameterization:
83+
arm_name = arm.name
84+
break
85+
if arm_name is None:
86+
return None
7487
return parameterization, trial_index, arm_name
7588

7689

0 commit comments

Comments
 (0)