Skip to content

Commit 242fb43

Browse files
andycylmetafacebook-github-bot
authored andcommitted
Support MultiTypeExperiment in Instantiation (#2939)
Summary: Pull Request resolved: #2939 1. **InstatntiationBase:** Add support returning MultiTypeExperiment in InstatntiationBase._make_experiment. 2. **MultiTypeExperiment:** Add add_tracking_metrics function in MultiTypeExperiment to support batch adding metrics when creating a MultiTypeExperiment. 3. **AxClient**: Add support for creating MultiTypeExperiment, add_trial_type and add_tracking_metrics. Reviewed By: sdaulton Differential Revision: D64612495 fbshipit-source-id: d162b8965cfdb516bd4d484d321cff42835b617f
1 parent d983045 commit 242fb43

6 files changed

+273
-9
lines changed

ax/core/multi_type_experiment.py

+42
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
default_trial_type: str,
5050
default_runner: Runner,
5151
optimization_config: OptimizationConfig | None = None,
52+
tracking_metrics: list[Metric] | None = None,
5253
status_quo: Arm | None = None,
5354
description: str | None = None,
5455
is_test: bool = False,
@@ -65,6 +66,7 @@ def __init__(
6566
default_runner: Default runner for trials of the default type.
6667
optimization_config: Optimization config of the experiment.
6768
tracking_metrics: Additional tracking metrics not used for optimization.
69+
These are associated with the default trial type.
6870
runner: Default runner used for trials on this experiment.
6971
status_quo: Arm representing existing "control" arm.
7072
description: Description of the experiment.
@@ -101,6 +103,7 @@ def __init__(
101103
experiment_type=experiment_type,
102104
properties=properties,
103105
default_data_type=default_data_type,
106+
tracking_metrics=tracking_metrics,
104107
)
105108

106109
def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment":
@@ -163,6 +166,45 @@ def add_tracking_metric(
163166
self._metric_to_canonical_name[metric.name] = canonical_name
164167
return self
165168

169+
def add_tracking_metrics(
170+
self,
171+
metrics: list[Metric],
172+
metrics_to_trial_types: dict[str, str] | None = None,
173+
canonical_names: dict[str, str] | None = None,
174+
) -> Experiment:
175+
"""Add a list of new metrics to the experiment.
176+
177+
If any of the metrics are already defined on the experiment,
178+
we raise an error and don't add any of them to the experiment
179+
180+
Args:
181+
metrics: Metrics to be added.
182+
metrics_to_trial_types: The mapping from metric names to corresponding
183+
trial types for each metric. If provided, the metrics will be
184+
added to their trial types. If not provided, then the default
185+
trial type will be used.
186+
canonical_names: A mapping of metric names to their
187+
canonical names(The default metrics for which the metrics are
188+
proxies.)
189+
190+
Returns:
191+
The experiment with the added metrics.
192+
"""
193+
metrics_to_trial_types = metrics_to_trial_types or {}
194+
canonical_name = None
195+
for metric in metrics:
196+
if canonical_names is not None:
197+
canonical_name = none_throws(canonical_names).get(metric.name, None)
198+
199+
self.add_tracking_metric(
200+
metric=metric,
201+
trial_type=metrics_to_trial_types.get(
202+
metric.name, self._default_trial_type
203+
),
204+
canonical_name=canonical_name,
205+
)
206+
return self
207+
166208
# pyre-fixme[14]: `update_tracking_metric` overrides method defined in
167209
# `Experiment` inconsistently.
168210
def update_tracking_metric(

ax/core/tests/test_multi_type_experiment.py

+34
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,40 @@ def test_runner_for_trial_type(self) -> None:
171171
):
172172
self.experiment.runner_for_trial_type(trial_type="invalid")
173173

174+
def test_add_tracking_metrics(self) -> None:
175+
type1_metrics = [
176+
BraninMetric("m3_type1", ["x1", "x2"]),
177+
BraninMetric("m4_type1", ["x1", "x2"]),
178+
]
179+
type2_metrics = [
180+
BraninMetric("m3_type2", ["x1", "x2"]),
181+
BraninMetric("m4_type2", ["x1", "x2"]),
182+
]
183+
default_type_metrics = [
184+
BraninMetric("m5_default_type", ["x1", "x2"]),
185+
]
186+
self.experiment.add_tracking_metrics(
187+
metrics=type1_metrics + type2_metrics + default_type_metrics,
188+
metrics_to_trial_types={
189+
"m3_type1": "type1",
190+
"m4_type1": "type1",
191+
"m3_type2": "type2",
192+
"m4_type2": "type2",
193+
},
194+
)
195+
self.assertDictEqual(
196+
self.experiment._metric_to_trial_type,
197+
{
198+
"m1": "type1",
199+
"m2": "type2",
200+
"m3_type1": "type1",
201+
"m4_type1": "type1",
202+
"m3_type2": "type2",
203+
"m4_type2": "type2",
204+
"m5_default_type": "type1",
205+
},
206+
)
207+
174208

175209
class MultiTypeExperimentUtilsTest(TestCase):
176210
def setUp(self) -> None:

ax/service/ax_client.py

+41-8
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,22 @@
2727
from ax.core.generator_run import GeneratorRun
2828
from ax.core.map_data import MapData
2929
from ax.core.map_metric import MapMetric
30+
from ax.core.multi_type_experiment import MultiTypeExperiment
3031
from ax.core.objective import MultiObjective, Objective
3132
from ax.core.observation import ObservationFeatures
3233
from ax.core.optimization_config import (
3334
MultiObjectiveOptimizationConfig,
3435
OptimizationConfig,
3536
)
37+
from ax.core.runner import Runner
3638
from ax.core.trial import Trial
3739
from ax.core.types import (
3840
TEvaluationOutcome,
3941
TModelPredictArm,
4042
TParameterization,
4143
TParamValue,
4244
)
45+
4346
from ax.core.utils import get_pending_observation_features_based_on_trial_status
4447
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
4548
from ax.early_stopping.utils import estimate_early_stopping_savings
@@ -90,6 +93,7 @@
9093
from ax.utils.common.typeutils import checked_cast
9194
from pyre_extensions import assert_is_instance, none_throws
9295

96+
9397
logger: Logger = get_logger(__name__)
9498

9599

@@ -251,6 +255,8 @@ def create_experiment(
251255
immutable_search_space_and_opt_config: bool = True,
252256
is_test: bool = False,
253257
metric_definitions: dict[str, dict[str, Any]] | None = None,
258+
default_trial_type: str | None = None,
259+
default_runner: Runner | None = None,
254260
) -> None:
255261
"""Create a new experiment and save it if DBSettings available.
256262
@@ -316,6 +322,15 @@ def create_experiment(
316322
to that metric. Note these are modified in-place. Each
317323
Metric must have its own dictionary (metrics cannot share a
318324
single dictionary object).
325+
default_trial_type: The default trial type if multiple
326+
trial types are intended to be used in the experiment. If specified,
327+
a MultiTypeExperiment will be created. Otherwise, a single-type
328+
Experiment will be created.
329+
default_runner: The default runner in this experiment.
330+
This applies to MultiTypeExperiment (when default_trial_type
331+
is specified) and needs to be specified together with
332+
default_trial_type. This will be ignored for single-type Experiment
333+
(when default_trial_type is not specified).
319334
"""
320335
self._validate_early_stopping_strategy(support_intermediate_data)
321336

@@ -344,6 +359,8 @@ def create_experiment(
344359
support_intermediate_data=support_intermediate_data,
345360
immutable_search_space_and_opt_config=immutable_search_space_and_opt_config,
346361
is_test=is_test,
362+
default_trial_type=default_trial_type,
363+
default_runner=default_runner,
347364
**objective_kwargs,
348365
)
349366
self._set_runner(experiment=experiment)
@@ -416,6 +433,8 @@ def add_tracking_metrics(
416433
self,
417434
metric_names: list[str],
418435
metric_definitions: dict[str, dict[str, Any]] | None = None,
436+
metrics_to_trial_types: dict[str, str] | None = None,
437+
canonical_names: dict[str, str] | None = None,
419438
) -> None:
420439
"""Add a list of new metrics to the experiment.
421440
@@ -428,20 +447,34 @@ def add_tracking_metrics(
428447
to that metric. Note these are modified in-place. Each
429448
Metric must have its is own dictionary (metrics cannot share a
430449
single dictionary object).
450+
metrics_to_trial_types: Only applicable to MultiTypeExperiment.
451+
The mapping from metric names to corresponding
452+
trial types for each metric. If provided, the metrics will be
453+
added with their respective trial types. If not provided, then the
454+
default trial type will be used.
455+
canonical_names: A mapping from metric name (of a particular trial type)
456+
to the metric name of the default trial type. Only applicable to
457+
MultiTypeExperiment.
431458
"""
432459
metric_definitions = (
433460
self.metric_definitions
434461
if metric_definitions is None
435462
else metric_definitions
436463
)
437-
self.experiment.add_tracking_metrics(
438-
metrics=[
439-
self._make_metric(
440-
name=metric_name, metric_definitions=metric_definitions
441-
)
442-
for metric_name in metric_names
443-
]
444-
)
464+
metric_objects = [
465+
self._make_metric(name=metric_name, metric_definitions=metric_definitions)
466+
for metric_name in metric_names
467+
]
468+
469+
if isinstance(self.experiment, MultiTypeExperiment):
470+
experiment = assert_is_instance(self.experiment, MultiTypeExperiment)
471+
experiment.add_tracking_metrics(
472+
metrics=metric_objects,
473+
metrics_to_trial_types=metrics_to_trial_types,
474+
canonical_names=canonical_names,
475+
)
476+
else:
477+
self.experiment.add_tracking_metrics(metrics=metric_objects)
445478

446479
@copy_doc(Experiment.remove_tracking_metric)
447480
def remove_tracking_metric(self, metric_name: str) -> None:

ax/service/tests/test_ax_client.py

+105-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ax.core.arm import Arm
2222
from ax.core.generator_run import GeneratorRun
2323
from ax.core.metric import Metric
24+
from ax.core.multi_type_experiment import MultiTypeExperiment
2425
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
2526
from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint
2627
from ax.core.parameter import (
@@ -57,6 +58,7 @@
5758
from ax.modelbridge.model_spec import ModelSpec
5859
from ax.modelbridge.random import RandomModelBridge
5960
from ax.modelbridge.registry import Models
61+
from ax.runners.synthetic import SyntheticRunner
6062

6163
from ax.service.ax_client import AxClient, ObjectiveProperties
6264
from ax.service.utils.best_point import (
@@ -83,7 +85,7 @@
8385
from ax.utils.testing.mock import mock_botorch_optimize
8486
from ax.utils.testing.modeling_stubs import get_observation1, get_observation1trans
8587
from botorch.test_functions.multi_objective import BraninCurrin
86-
from pyre_extensions import none_throws
88+
from pyre_extensions import assert_is_instance, none_throws
8789

8890
if TYPE_CHECKING:
8991
from ax.core.types import TTrialEvaluation
@@ -821,6 +823,7 @@ def test_create_experiment(self) -> None:
821823
is_test=True,
822824
)
823825
assert ax_client._experiment is not None
826+
self.assertEqual(ax_client.experiment.__class__.__name__, "Experiment")
824827
self.assertEqual(ax_client._experiment, ax_client.experiment)
825828
self.assertEqual(
826829
# pyre-fixme[16]: `Optional` has no attribute `search_space`.
@@ -903,6 +906,107 @@ def test_create_experiment(self) -> None:
903906
{"test_objective", "some_metric", "test_tracking_metric"},
904907
)
905908

909+
def test_create_multitype_experiment(self) -> None:
910+
"""
911+
Test create multitype experiment, add trial type, and add metrics to
912+
different trial types
913+
"""
914+
ax_client = AxClient(
915+
GenerationStrategy(
916+
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
917+
)
918+
)
919+
ax_client.create_experiment(
920+
name="test_experiment",
921+
parameters=[
922+
{
923+
"name": "x",
924+
"type": "range",
925+
"bounds": [0.001, 0.1],
926+
"value_type": "float",
927+
"log_scale": True,
928+
"digits": 6,
929+
},
930+
{
931+
"name": "y",
932+
"type": "choice",
933+
"values": [1, 2, 3],
934+
"value_type": "int",
935+
"is_ordered": True,
936+
},
937+
{"name": "x3", "type": "fixed", "value": 2, "value_type": "int"},
938+
{
939+
"name": "x4",
940+
"type": "range",
941+
"bounds": [1.0, 3.0],
942+
"value_type": "int",
943+
},
944+
{
945+
"name": "x5",
946+
"type": "choice",
947+
"values": ["one", "two", "three"],
948+
"value_type": "str",
949+
},
950+
{
951+
"name": "x6",
952+
"type": "range",
953+
"bounds": [1.0, 3.0],
954+
"value_type": "int",
955+
},
956+
],
957+
objectives={"test_objective": ObjectiveProperties(minimize=True)},
958+
outcome_constraints=["some_metric >= 3", "some_metric <= 4.0"],
959+
parameter_constraints=["x4 <= x6"],
960+
tracking_metric_names=["test_tracking_metric"],
961+
is_test=True,
962+
default_trial_type="test_trial_type",
963+
default_runner=SyntheticRunner(),
964+
)
965+
966+
self.assertEqual(ax_client.experiment.__class__.__name__, "MultiTypeExperiment")
967+
experiment = assert_is_instance(ax_client.experiment, MultiTypeExperiment)
968+
self.assertEqual(
969+
experiment._trial_type_to_runner["test_trial_type"].__class__.__name__,
970+
"SyntheticRunner",
971+
)
972+
self.assertEqual(
973+
experiment._metric_to_trial_type,
974+
{
975+
"test_tracking_metric": "test_trial_type",
976+
"test_objective": "test_trial_type",
977+
"some_metric": "test_trial_type",
978+
},
979+
)
980+
experiment.add_trial_type(
981+
trial_type="test_trial_type_2",
982+
runner=SyntheticRunner(),
983+
)
984+
ax_client.add_tracking_metrics(
985+
metric_names=[
986+
"some_metric2_type1",
987+
"some_metric3_type1",
988+
"some_metric4_type2",
989+
"some_metric5_type2",
990+
],
991+
metrics_to_trial_types={
992+
"some_metric2_type1": "test_trial_type",
993+
"some_metric4_type2": "test_trial_type_2",
994+
"some_metric5_type2": "test_trial_type_2",
995+
},
996+
)
997+
self.assertEqual(
998+
experiment._metric_to_trial_type,
999+
{
1000+
"test_tracking_metric": "test_trial_type",
1001+
"test_objective": "test_trial_type",
1002+
"some_metric": "test_trial_type",
1003+
"some_metric2_type1": "test_trial_type",
1004+
"some_metric3_type1": "test_trial_type",
1005+
"some_metric4_type2": "test_trial_type_2",
1006+
"some_metric5_type2": "test_trial_type_2",
1007+
},
1008+
)
1009+
9061010
def test_create_single_objective_experiment_with_objectives_dict(self) -> None:
9071011
ax_client = AxClient(
9081012
GenerationStrategy(

ax/service/tests/test_instantiation_utils.py

+19
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
RangeParameter,
1818
)
1919
from ax.core.search_space import HierarchicalSearchSpace
20+
from ax.runners.synthetic import SyntheticRunner
2021
from ax.service.utils.instantiation import InstantiationBase
2122
from ax.utils.common.testutils import TestCase
2223
from ax.utils.common.typeutils import checked_cast
@@ -431,3 +432,21 @@ def test_hss(self) -> None:
431432
self.assertIsInstance(search_space, HierarchicalSearchSpace)
432433
# pyre-fixme[16]: `SearchSpace` has no attribute `_root`.
433434
self.assertEqual(search_space._root.name, "root")
435+
436+
def test_make_multitype_experiment_with_default_trial_type(self) -> None:
437+
experiment = InstantiationBase.make_experiment(
438+
name="test_make_experiment",
439+
parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}],
440+
tracking_metric_names=None,
441+
default_trial_type="test_trial_type",
442+
default_runner=SyntheticRunner(),
443+
)
444+
self.assertEqual(experiment.__class__.__name__, "MultiTypeExperiment")
445+
446+
def test_make_single_type_experiment_with_no_default_trial_type(self) -> None:
447+
experiment = InstantiationBase.make_experiment(
448+
name="test_make_experiment",
449+
parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}],
450+
tracking_metric_names=None,
451+
)
452+
self.assertEqual(experiment.__class__.__name__, "Experiment")

0 commit comments

Comments
 (0)