forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_arm_effects.py
More file actions
474 lines (415 loc) · 18.3 KB
/
test_arm_effects.py
File metadata and controls
474 lines (415 loc) · 18.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-safe
import json
from itertools import product
from ax.adapter.registry import Generators
from ax.analysis.plotly.arm_effects import ArmEffectsPlot, compute_arm_effects_adhoc
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig
from ax.core.arm import Arm
from ax.core.trial_status import DEFAULT_ANALYSIS_STATUSES, TrialStatus
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
get_non_failed_arm_names,
get_offline_experiments,
get_online_experiments,
)
from ax.utils.testing.mock import mock_botorch_optimize
from ax.utils.testing.modeling_stubs import (
get_default_generation_strategy_at_MBM_node,
get_sobol_MBM_MTGP_gs,
)
from pyre_extensions import assert_is_instance, none_throws
class TestArmEffectsPlot(TestCase):
    """Tests for ``ArmEffectsPlot`` on a small two-parameter ``Client`` experiment.

    ``setUp`` completes one trial, fails a second, then completes five more, so
    the experiment contains a failed trial alongside completed ones.
    """

    # Columns expected on the card dataframe for metric "foo".
    _EXPECTED_COLUMNS: frozenset[str] = frozenset(
        {
            "trial_index",
            "arm_name",
            "trial_status",
            "status_reason",
            "generation_node",
            "foo_mean",
            "foo_sem",
        }
    )

    @mock_botorch_optimize
    def setUp(self) -> None:
        super().setUp()
        self.client = Client()
        self.client.configure_experiment(
            name="test_experiment",
            parameters=[
                RangeParameterConfig(
                    name="x1",
                    parameter_type="float",
                    bounds=(0, 1),
                ),
                RangeParameterConfig(
                    name="x2",
                    parameter_type="float",
                    bounds=(0, 1),
                ),
            ],
        )
        self.client.configure_optimization(
            objective="foo", outcome_constraints=["bar >= -0.5"]
        )
        # Get two trials and fail one, giving us a ragged structure
        self.client.get_next_trials(max_trials=2)
        self.client.complete_trial(trial_index=0, raw_data={"foo": 1.0, "bar": 2.0})
        self.client.mark_trial_failed(trial_index=1)
        # Complete 5 trials successfully. Raw data is given as plain floats, i.e.
        # without SEMs.
        for _ in range(5):
            for trial_index, parameterization in self.client.get_next_trials(
                max_trials=1
            ).items():
                x1 = assert_is_instance(parameterization["x1"], float)
                x2 = assert_is_instance(parameterization["x2"], float)
                self.client.complete_trial(
                    trial_index=trial_index,
                    raw_data={"foo": x1, "bar": x1 - 2 * x2},
                )

    def test_trial_statuses_behavior(self) -> None:
        # When neither trial_statuses nor trial_index is provided,
        # should use default statuses (excluding ABANDONED, STALE, and FAILED)
        analysis = ArmEffectsPlot(metric_name="foo")
        self.assertEqual(
            set(none_throws(analysis.trial_statuses)),
            DEFAULT_ANALYSIS_STATUSES,
        )
        # When trial_statuses is explicitly provided, it should be used
        explicit_statuses = [TrialStatus.COMPLETED, TrialStatus.RUNNING]
        analysis = ArmEffectsPlot(metric_name="foo", trial_statuses=explicit_statuses)
        self.assertEqual(analysis.trial_statuses, explicit_statuses)
        # When trial_index is provided (and trial_statuses is None),
        # trial_statuses should be None to allow filtering by trial_index
        analysis = ArmEffectsPlot(metric_name="foo", trial_index=0)
        self.assertIsNone(analysis.trial_statuses)

    def test_validation(self) -> None:
        # Unknown metric names are rejected.
        with self.assertRaisesRegex(
            UserInputError, "Requested metrics .* are not present in the experiment."
        ):
            ArmEffectsPlot(metric_name="baz").compute(
                experiment=self.client._experiment,
                generation_strategy=self.client._generation_strategy,
            )
        # Unknown trial indices are rejected.
        with self.assertRaisesRegex(
            UserInputError, "Trial with index .* not found in experiment."
        ):
            ArmEffectsPlot(metric_name="foo", trial_index=1998).compute(
                experiment=self.client._experiment,
                generation_strategy=self.client._generation_strategy,
            )

    def test_compute_raw(self) -> None:
        default_analysis = ArmEffectsPlot(
            metric_name="foo", use_model_predictions=False
        )
        card = default_analysis.compute(
            experiment=self.client._experiment,
            generation_strategy=self.client._generation_strategy,
        )
        self.assertEqual(set(card.df.columns), set(self._EXPECTED_COLUMNS))
        # Check that we have one row per arm from non-failed trials and that each
        # arm appears only once
        non_failed_arms = get_non_failed_arm_names(self.client._experiment)
        self.assertEqual(len(card.df), len(non_failed_arms))
        for arm_name in non_failed_arms:
            self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
        # Raw data was completed without SEMs, so every SEM should be NaN.
        self.assertTrue(card.df["foo_sem"].isna().all())

    def test_compute_with_modeled(self) -> None:
        default_analysis = ArmEffectsPlot(metric_name="foo", use_model_predictions=True)
        card = default_analysis.compute(
            experiment=self.client._experiment,
            generation_strategy=self.client._generation_strategy,
        )
        self.assertEqual(set(card.df.columns), set(self._EXPECTED_COLUMNS))
        # Check that we have one row per arm from non-failed trials and that each
        # arm appears only once
        non_failed_arms = get_non_failed_arm_names(self.client._experiment)
        self.assertEqual(len(card.df), len(non_failed_arms))
        for arm_name in non_failed_arms:
            self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
        # Model predictions come with uncertainty, so no SEM should be NaN.
        self.assertFalse(card.df["foo_sem"].isna().any())

    def test_compute_adhoc(self) -> None:
        # Use the same kwargs for typical and adhoc
        kwargs = {
            "metric_name": "foo",
            "use_model_predictions": True,
            "additional_arms": [Arm(parameters={"x1": 0, "x2": 0})],
            "label": "f",
        }
        # pyre-ignore[6]: Unsafe kwargs usage on purpose
        analysis = ArmEffectsPlot(**kwargs)
        cards = analysis.compute(
            experiment=self.client._experiment,
            generation_strategy=self.client._generation_strategy,
        )
        # The adhoc API takes lists/dicts keyed by metric instead of the scalar
        # metric_name/label kwargs, so pop those and translate them.
        metric_name = assert_is_instance(kwargs.pop("metric_name"), str)
        adhoc_cards = compute_arm_effects_adhoc(
            experiment=self.client._experiment,
            generation_strategy=self.client._generation_strategy,
            metric_names=[metric_name],
            labels={metric_name: assert_is_instance(kwargs.pop("label"), str)},
            # pyre-ignore[6]: Unsafe kwargs usage on purpose
            **kwargs,
        )
        self.assertEqual(cards, adhoc_cards.children[0])

    @TestCase.ax_long_test(
        reason=(
            "Adapter.predict still too slow under @mock_botorch_optimize for this test"
        )
    )
    @mock_botorch_optimize
    def test_online(self) -> None:
        # Test ArmEffectsPlot can be computed for a variety of experiments which
        # resemble those we see in an online setting.
        for experiment in get_online_experiments():
            arm = Generators.SOBOL(experiment=experiment).gen(n=1).arms[0]
            arm.name = "additional_arm"
            generation_strategy = get_default_generation_strategy_at_MBM_node(
                experiment=experiment
            )
            generation_strategy.current_node._fit(experiment=experiment)
            adapter = none_throws(generation_strategy.adapter)
            metric_names = [
                adapter._experiment.signature_to_metric[signature].name
                for signature in adapter.metric_signatures
            ]
            for use_model_predictions, trial_index in product(
                [True, False], [None, 0]
            ):
                # Additional arms are only passed together with model predictions;
                # otherwise the only option is None, so identical plots are not
                # recomputed.
                additional_arms_options: list[list[Arm] | None] = (
                    [[arm], None] if use_model_predictions else [None]
                )
                for additional_arms, metric_name in product(
                    additional_arms_options, metric_names
                ):
                    analysis = ArmEffectsPlot(
                        metric_name=metric_name,
                        use_model_predictions=use_model_predictions,
                        trial_index=trial_index,
                        additional_arms=additional_arms,
                    )
                    card = analysis.compute(
                        experiment=experiment,
                        adapter=adapter,
                    )
                    if additional_arms is not None:
                        # validate that we plotted the additional arm
                        self.assertIn(
                            arm.name,
                            json.loads(none_throws(card.blob))["layout"]["xaxis"][
                                "ticktext"
                            ],
                        )

    @TestCase.ax_long_test(
        reason=(
            "Adapter.predict still too slow under @mock_botorch_optimize for this test"
        )
    )
    @mock_botorch_optimize
    def test_offline(self) -> None:
        # Test ArmEffectsPlot can be computed for a variety of experiments which
        # resemble those we see in an offline setting.
        for experiment in get_offline_experiments():
            generation_strategy = get_default_generation_strategy_at_MBM_node(
                experiment=experiment
            )
            generation_strategy.current_node._fit(experiment=experiment)
            adapter = none_throws(generation_strategy.adapter)
            model_metric_names = [
                adapter._experiment.signature_to_metric[signature].name
                for signature in adapter.metric_signatures
            ]
            for use_model_predictions, trial_index in product(
                [True, False], [None, 0]
            ):
                # Additional arms are only passed together with model predictions;
                # otherwise the only option is None, so identical plots are not
                # recomputed.
                additional_arms_options: list[list[Arm] | None] = [None]
                if use_model_predictions:
                    # An extra arm with every parameter set to 0.
                    zero_arm = Arm(
                        parameters={
                            parameter_name: 0
                            for parameter_name in experiment.search_space.parameters
                        }
                    )
                    additional_arms_options = [[zero_arm], None]
                for additional_arms, metric_name in product(
                    additional_arms_options, model_metric_names
                ):
                    analysis = ArmEffectsPlot(
                        metric_name=metric_name,
                        use_model_predictions=use_model_predictions,
                        trial_index=trial_index,
                        additional_arms=additional_arms,
                    )
                    analysis.compute(
                        experiment=experiment,
                        adapter=adapter,
                    )
class TestArmEffectsPlotInfeasibility(TestCase):
    """Tests for how ``ArmEffectsPlot`` marks arms violating an outcome constraint.

    Every experiment here optimizes "foo" subject to ``bar >= 0.5``; arms whose
    observed "bar" is below 0.5 are infeasible.
    """

    @staticmethod
    def _make_client(
        name: str, trial_data: list[dict[str, float | tuple[float, float]]]
    ) -> Client:
        """Build a two-parameter client and complete one trial per ``trial_data``
        entry, using that entry as the trial's raw data.
        """
        client = Client()
        client.configure_experiment(
            name=name,
            parameters=[
                RangeParameterConfig(
                    name="x1",
                    parameter_type="float",
                    bounds=(0, 1),
                ),
                RangeParameterConfig(
                    name="x2",
                    parameter_type="float",
                    bounds=(0, 1),
                ),
            ],
        )
        # Constraint on "bar" (non-objective) so it stays as an OutcomeConstraint.
        client.configure_optimization(
            objective="foo",
            outcome_constraints=["bar >= 0.5"],
        )
        for raw_data in trial_data:
            for trial_index in client.get_next_trials(max_trials=1):
                client.complete_trial(trial_index=trial_index, raw_data=raw_data)
        return client

    def setUp(self) -> None:
        super().setUp()
        # Trial data: (foo, bar). Arms with bar < 0.5 are infeasible.
        self.client = self._make_client(
            name="test_infeasibility",
            trial_data=[
                {"foo": 1.0, "bar": (0.9, 0.01)},  # feasible
                {"foo": 0.5, "bar": (0.8, 0.01)},  # feasible
                {"foo": 0.8, "bar": (0.1, 0.01)},  # infeasible
                {"foo": 0.3, "bar": (0.2, 0.01)},  # infeasible
            ],
        )

    def test_infeasible_arms_have_red_outline(self) -> None:
        card = ArmEffectsPlot(metric_name="bar", use_model_predictions=False).compute(
            experiment=self.client._experiment,
            generation_strategy=self.client._generation_strategy,
        )
        fig_data = json.loads(none_throws(card.blob))
        # All infeasible arms are in a single trace with legendgroup="infeasible"
        infeasible_traces = [
            t for t in fig_data["data"] if t.get("legendgroup") == "infeasible"
        ]
        # Single trace containing all infeasible arms
        self.assertEqual(len(infeasible_traces), 1)
        trace = infeasible_traces[0]
        self.assertEqual(trace["marker"]["line"]["color"], "red")
        self.assertGreater(trace["marker"]["line"]["width"], 0)
        # The trace should contain 2 infeasible points
        self.assertEqual(len([x for x in trace["x"] if x is not None]), 2)
        # Legend entry is on the same trace
        self.assertTrue(trace["showlegend"])
        self.assertEqual(trace["legendgroup"], "infeasible")

    def test_no_infeasible_legend_when_all_feasible(self) -> None:
        # Create an experiment where all arms satisfy the constraint
        client = self._make_client(
            name="all_feasible",
            trial_data=[
                {"foo": 1.0, "bar": (bar_val, 0.01)}
                for bar_val in [0.9, 0.8, 0.7, 0.6]
            ],
        )
        card = ArmEffectsPlot(metric_name="bar", use_model_predictions=False).compute(
            experiment=client._experiment,
            generation_strategy=client._generation_strategy,
        )
        fig_data = json.loads(none_throws(card.blob))
        # Match on the legend group used by the positive test as well as the
        # legend label, so an infeasible trace is still caught if only one of
        # the two is renamed.
        infeasible_traces = [
            t
            for t in fig_data["data"]
            if t.get("legendgroup") == "infeasible"
            or t.get("name") == "Likely Infeasible"
        ]
        self.assertEqual(len(infeasible_traces), 0)
class TestArmEffectsPlotRel(TestCase):
    """Tests for ``ArmEffectsPlot`` with ``relativize=True`` on a batch experiment.

    Each batch trial also contains the experiment's status quo arm, which the
    relativized plot title refers to.
    """

    def setUp(self) -> None:
        super().setUp()
        self.experiment = get_branin_experiment(with_status_quo=True)
        self.generation_strategy = get_sobol_MBM_MTGP_gs()
        self.generation_strategy.experiment = self.experiment
        # Run 2 batch trials, each requesting n=3 generated arms plus the status quo.
        for _ in range(2):
            generator_runs = self.generation_strategy.gen(
                experiment=self.experiment, n=3
            )[0]
            batch = self.experiment.new_batch_trial(generator_runs=generator_runs)
            batch.add_status_quo_arm(weight=1.0).mark_completed(unsafe=True)
        self.experiment.fetch_data()

    def test_compute_with_relativize(self) -> None:
        # Derive the status quo name from the experiment instead of hard-coding it.
        status_quo_name = none_throws(self.experiment.status_quo).name
        non_status_quo_arms = [
            arm_name
            for arm_name in self.experiment.arms_by_name
            if arm_name != status_quo_name
        ]
        for use_model_predictions in [True, False]:
            with self.subTest(use_model_predictions=use_model_predictions):
                analysis = ArmEffectsPlot(
                    metric_name="branin",
                    use_model_predictions=use_model_predictions,
                    relativize=True,
                )
                cards = analysis.compute(
                    experiment=self.experiment,
                    generation_strategy=self.generation_strategy,
                ).flatten()
                self.assertEqual(len(cards), 1)
                card = cards[0]
                self.assertEqual(
                    set(card.df.columns),
                    {
                        "trial_index",
                        "arm_name",
                        "trial_status",
                        "status_reason",
                        "generation_node",
                        "branin_mean",
                        "branin_sem",
                    },
                )
                # Check that we have one row per arm and that each arm appears only
                # once. Exclude the status quo since it is repeated between trials.
                card_arms = card.df[card.df.arm_name != status_quo_name].arm_name
                self.assertEqual(len(card_arms), len(non_status_quo_arms))
                for arm_name in non_status_quo_arms:
                    self.assertEqual((card.df["arm_name"] == arm_name).sum(), 1)
                self.assertFalse(card.df["branin_mean"].isna().any())
                self.assertFalse(card.df["branin_sem"].isna().any())
                # The title should include status quo name when relativize is True
                expected_prefix = "Modeled" if use_model_predictions else "Observed"
                self.assertIn(expected_prefix, card.title)
                self.assertIn("Arm Effects on branin", card.title)
                self.assertIn(f'relative to "{status_quo_name}"', card.title)