Skip to content

Commit 643b979

Browse files
mpolson64facebook-github-bot
authored andcommitted
Add trial_type support to base Experiment metric methods (#5002)
Summary: Phase 2 of moving MultiTypeExperiment features into base Experiment. Updates the base Experiment metric management methods (`add_metric`, `update_metric`, `remove_metric`) to accept an optional `trial_type` parameter. When provided, metrics are associated with the specified trial type in `_trial_type_to_metric_names`. The `__init__` and `optimization_config` setter also now register metrics when `default_trial_type` is set. The deprecated wrappers (`add_tracking_metric`, `add_tracking_metrics`, `update_tracking_metric`) now accept and pass through `trial_type` and `canonical_name` parameters. On MultiTypeExperiment, overrides are simplified to delegate to the base class methods: - `add_tracking_metric` delegates to `self.add_metric()` - `add_tracking_metrics` override removed (inherited from base) - `update_tracking_metric` delegates to `self.update_metric()` - `remove_tracking_metric` replaced with `remove_metric` override Differential Revision: D94986440
1 parent ac76c57 commit 643b979

3 files changed

Lines changed: 117 additions & 110 deletions

File tree

ax/core/experiment.py

Lines changed: 99 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def __init__(
202202
# a naming collision occurs.
203203
for m in [*(tracking_metrics or []), *(metrics or [])]:
204204
self._metrics[m.name] = m
205+
if self._default_trial_type is not None:
206+
self._trial_type_to_metric_names.setdefault(
207+
self._default_trial_type, set()
208+
).add(m.name)
205209

206210
# call setters defined below
207211
self.status_quo = status_quo
@@ -590,6 +594,12 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
590594
"but not found on experiment. Add it first with add_metric()."
591595
)
592596
self._optimization_config = optimization_config
597+
resolved_trial_type = self._resolve_trial_type(None)
598+
if resolved_trial_type is not None:
599+
for metric_name in optimization_config.metric_names:
600+
self._trial_type_to_metric_names.setdefault(
601+
resolved_trial_type, set()
602+
).add(metric_name)
593603

594604
@property
595605
def is_moo_problem(self) -> bool:
@@ -842,7 +852,41 @@ def get_metric(self, name: str) -> Metric:
842852
)
843853
return self._metrics[name]
844854

845-
def add_metric(self, metric: Metric) -> Self:
855+
def _resolve_trial_type(self, trial_type: str | None) -> str | None:
856+
"""Resolve an explicit or default trial type and validate it.
857+
858+
Returns ``trial_type`` if explicitly provided (after validating via
859+
``supports_trial_type``), falls back to ``_default_trial_type`` when
860+
available, and raises ``ValueError`` if this experiment uses trial types
861+
(``_trial_type_to_metric_names`` is non-empty) but none could be
862+
resolved.
863+
864+
Args:
865+
trial_type: The explicitly provided trial type, or ``None``.
866+
867+
Returns:
868+
The resolved trial type, which may be ``None`` for single-type
869+
experiments.
870+
871+
Raises:
872+
ValueError: If ``trial_type`` is provided but not supported, or if
873+
no trial type could be resolved for a multi-type experiment.
874+
"""
875+
if trial_type is not None:
876+
if not self.supports_trial_type(trial_type):
877+
raise ValueError(f"`{trial_type}` is not a supported trial type.")
878+
return trial_type
879+
if self._default_trial_type is not None:
880+
return self._default_trial_type
881+
if self._trial_type_to_metric_names:
882+
raise ValueError(
883+
"This experiment has trial-type-aware metrics but no "
884+
"`trial_type` was specified and no `default_trial_type` is set. "
885+
"Please specify a `trial_type`."
886+
)
887+
return None
888+
889+
def add_metric(self, metric: Metric, trial_type: str | None = None) -> Self:
846890
"""Add a new metric to the experiment.
847891
848892
Metrics that are not referenced by the experiment's optimization config
@@ -851,54 +895,98 @@ def add_metric(self, metric: Metric) -> Self:
851895
852896
Args:
853897
metric: Metric to be added.
898+
trial_type: If provided, associates the metric with this trial type.
899+
When ``None`` and a ``default_trial_type`` is set, defaults to
900+
the default trial type.
901+
902+
Raises:
903+
ValueError: If the metric already exists, the trial type is not
904+
supported, or trial types are in use but none could be resolved.
854905
"""
855906
if metric.name in self._metrics:
856907
raise ValueError(
857908
f"Metric `{metric.name}` already defined on experiment. "
858909
"Use `update_metric` to update an existing metric definition."
859910
)
911+
trial_type = self._resolve_trial_type(trial_type)
912+
if trial_type is not None:
913+
self._trial_type_to_metric_names.setdefault(trial_type, set()).add(
914+
metric.name
915+
)
860916
self._metrics[metric.name] = metric
861917
return self
862918

863-
def add_tracking_metric(self, metric: Metric) -> Self:
919+
def add_tracking_metric(
920+
self,
921+
metric: Metric,
922+
trial_type: str | None = None,
923+
canonical_name: str | None = None,
924+
) -> Self:
864925
"""*Deprecated.* Use ``add_metric`` instead."""
865926
warnings.warn(
866927
"add_tracking_metric is deprecated. Use add_metric instead.",
867928
DeprecationWarning,
868929
stacklevel=2,
869930
)
870-
return self.add_metric(metric)
931+
return self.add_metric(metric, trial_type=trial_type)
871932

872-
def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment:
933+
def add_tracking_metrics(
934+
self,
935+
metrics: list[Metric],
936+
metrics_to_trial_types: dict[str, str] | None = None,
937+
canonical_names: dict[str, str] | None = None,
938+
) -> Experiment:
873939
"""*Deprecated.* Use ``add_metric`` instead."""
874940
warnings.warn(
875941
"add_tracking_metrics is deprecated. Use add_metric instead.",
876942
DeprecationWarning,
877943
stacklevel=2,
878944
)
945+
metrics_to_trial_types = metrics_to_trial_types or {}
879946
for metric in metrics:
880-
self.add_metric(metric)
947+
canonical_name = (canonical_names or {}).get(metric.name)
948+
self.add_tracking_metric(
949+
metric=metric,
950+
trial_type=metrics_to_trial_types.get(metric.name),
951+
canonical_name=canonical_name,
952+
)
881953
return self
882954

883-
def update_metric(self, metric: Metric) -> Self:
955+
def update_metric(self, metric: Metric, trial_type: str | None = None) -> Self:
884956
"""Redefine a metric that already exists on the experiment.
885957
886958
Args:
887959
metric: New metric definition.
960+
trial_type: If provided, reassociates the metric with this trial
961+
type. When ``None``, keeps the metric's existing trial type.
888962
"""
889963
if metric.name not in self._metrics:
890964
raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.")
965+
if trial_type is not None:
966+
trial_type = self._resolve_trial_type(trial_type)
967+
# Remove from any existing trial type set
968+
for names in self._trial_type_to_metric_names.values():
969+
names.discard(metric.name)
970+
# Add to new trial type set
971+
self._trial_type_to_metric_names.setdefault(trial_type, set()).add(
972+
metric.name
973+
)
891974
self._metrics[metric.name] = metric
892975
return self
893976

894-
def update_tracking_metric(self, metric: Metric) -> Experiment:
977+
def update_tracking_metric(
978+
self,
979+
metric: Metric,
980+
trial_type: str | None = None,
981+
canonical_name: str | None = None,
982+
) -> Experiment:
895983
"""*Deprecated.* Use ``update_metric`` instead."""
896984
warnings.warn(
897985
"update_tracking_metric is deprecated. Use update_metric instead.",
898986
DeprecationWarning,
899987
stacklevel=2,
900988
)
901-
return self.update_metric(metric)
989+
return self.update_metric(metric, trial_type=trial_type)
902990

903991
def remove_metric(self, metric_name: str) -> Self:
904992
"""Remove a metric from the experiment.
@@ -919,6 +1007,9 @@ def remove_metric(self, metric_name: str) -> Self:
9191007
f"Metric `{metric_name}` is referenced by the optimization config "
9201008
"and cannot be removed. Update the optimization config first."
9211009
)
1010+
# Clean up _trial_type_to_metric_names
1011+
for names in self._trial_type_to_metric_names.values():
1012+
names.discard(metric_name)
9221013
del self._metrics[metric_name]
9231014
return self
9241015

ax/core/multi_type_experiment.py

Lines changed: 18 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,13 @@ def __init__(
9696
default_data_type=default_data_type,
9797
)
9898

99-
# Ensure tracking metrics are registered in _metric_to_trial_type
100-
# and _trial_type_to_metric_names.
101-
# super().__init__ sets self._metrics directly, bypassing
102-
# add_tracking_metric, so tracking metrics won't be in
103-
# _metric_to_trial_type yet.
99+
# Ensure tracking metrics are registered in _metric_to_trial_type.
100+
# The base __init__ handles _trial_type_to_metric_names.
104101
for m in tracking_metrics or []:
105102
if m.name not in self._metric_to_trial_type:
106-
tt = none_throws(self._default_trial_type)
107-
self._metric_to_trial_type[m.name] = tt
108-
self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name)
103+
self._metric_to_trial_type[m.name] = none_throws(
104+
self._default_trial_type
105+
)
109106

110107
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
111108
"""Add a new trial_type to be supported by this experiment.
@@ -127,12 +124,11 @@ def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
127124
def optimization_config(self, optimization_config: OptimizationConfig) -> None:
128125
# pyre-fixme[16]: `Optional` has no attribute `fset`.
129126
Experiment.optimization_config.fset(self, optimization_config)
127+
# Base setter handles _trial_type_to_metric_names; update legacy dict.
130128
for metric_name in optimization_config.metric_names:
131-
# Optimization config metrics are required to be the default trial type
132-
# currently. TODO: remove that restriction (T202797235)
133-
tt = none_throws(self.default_trial_type)
134-
self._metric_to_trial_type[metric_name] = tt
135-
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name)
129+
self._metric_to_trial_type[metric_name] = none_throws(
130+
self.default_trial_type
131+
)
136132

137133
def update_runner(self, trial_type: str, runner: Runner) -> Self:
138134
"""Update the default runner for an existing trial_type.
@@ -163,56 +159,12 @@ def add_tracking_metric(
163159
"""
164160
if trial_type is None:
165161
trial_type = self._default_trial_type
166-
if not self.supports_trial_type(trial_type):
167-
raise ValueError(f"`{trial_type}` is not a supported trial type.")
168-
169-
super().add_tracking_metric(metric)
170-
tt = none_throws(trial_type)
171-
self._metric_to_trial_type[metric.name] = tt
172-
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
162+
self.add_metric(metric, trial_type=trial_type)
163+
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
173164
if canonical_name is not None:
174165
self._metric_to_canonical_name[metric.name] = canonical_name
175166
return self
176167

177-
def add_tracking_metrics(
178-
self,
179-
metrics: list[Metric],
180-
metrics_to_trial_types: dict[str, str] | None = None,
181-
canonical_names: dict[str, str] | None = None,
182-
) -> Experiment:
183-
"""Add a list of new metrics to the experiment.
184-
185-
If any of the metrics are already defined on the experiment,
186-
we raise an error and don't add any of them to the experiment
187-
188-
Args:
189-
metrics: Metrics to be added.
190-
metrics_to_trial_types: The mapping from metric names to corresponding
191-
trial types for each metric. If provided, the metrics will be
192-
added to their trial types. If not provided, then the default
193-
trial type will be used.
194-
canonical_names: A mapping of metric names to their
195-
canonical names(The default metrics for which the metrics are
196-
proxies.)
197-
198-
Returns:
199-
The experiment with the added metrics.
200-
"""
201-
metrics_to_trial_types = metrics_to_trial_types or {}
202-
canonical_name = None
203-
for metric in metrics:
204-
if canonical_names is not None:
205-
canonical_name = none_throws(canonical_names).get(metric.name, None)
206-
207-
self.add_tracking_metric(
208-
metric=metric,
209-
trial_type=metrics_to_trial_types.get(
210-
metric.name, self._default_trial_type
211-
),
212-
canonical_name=canonical_name,
213-
)
214-
return self
215-
216168
def update_tracking_metric(
217169
self,
218170
metric: Metric,
@@ -233,47 +185,17 @@ def update_tracking_metric(
233185
trial_type = self._metric_to_trial_type.get(
234186
metric.name, self._default_trial_type
235187
)
236-
oc = self.optimization_config
237-
oc_metric_names = oc.metric_names if oc else set()
238-
if metric.name in oc_metric_names and trial_type != self._default_trial_type:
239-
raise ValueError(
240-
f"Metric `{metric.name}` must remain a "
241-
f"`{self._default_trial_type}` metric because it is part of the "
242-
"optimization_config."
243-
)
244-
elif not self.supports_trial_type(trial_type):
245-
raise ValueError(f"`{trial_type}` is not a supported trial type.")
246-
247-
super().update_tracking_metric(metric)
248-
# Remove from old trial type set
249-
old_tt = self._metric_to_trial_type.get(metric.name)
250-
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
251-
self._trial_type_to_metric_names[old_tt].discard(metric.name)
252-
# Add to new trial type set
253-
tt = none_throws(trial_type)
254-
self._metric_to_trial_type[metric.name] = tt
255-
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
188+
self.update_metric(metric, trial_type=trial_type)
189+
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
256190
if canonical_name is not None:
257191
self._metric_to_canonical_name[metric.name] = canonical_name
258192
return self
259193

260-
@copy_doc(Experiment.remove_tracking_metric)
261-
def remove_tracking_metric(self, metric_name: str) -> Self:
262-
if metric_name not in self._metrics:
263-
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
264-
265-
# Clean up _trial_type_to_metric_names
266-
old_tt = self._metric_to_trial_type.get(metric_name)
267-
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
268-
self._trial_type_to_metric_names[old_tt].discard(metric_name)
269-
270-
# Required fields
271-
del self._metrics[metric_name]
272-
del self._metric_to_trial_type[metric_name]
273-
274-
# Optional
275-
if metric_name in self._metric_to_canonical_name:
276-
del self._metric_to_canonical_name[metric_name]
194+
@copy_doc(Experiment.remove_metric)
195+
def remove_metric(self, metric_name: str) -> Self:
196+
super().remove_metric(metric_name)
197+
self._metric_to_trial_type.pop(metric_name, None)
198+
self._metric_to_canonical_name.pop(metric_name, None)
277199
return self
278200

279201
@copy_doc(Experiment.fetch_data)

ax/core/tests/test_multi_type_experiment.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,6 @@ def test_BadBehavior(self) -> None:
125125
with self.assertRaises(ValueError):
126126
self.experiment.remove_tracking_metric("m3")
127127

128-
# Try to change optimization metric to non-primary trial type
129-
with self.assertRaises(ValueError):
130-
self.experiment.update_tracking_metric(
131-
BraninMetric("m1", ["x1", "x2"]), "type2"
132-
)
133-
134128
# Update metric definition for trial_type that doesn't exist
135129
with self.assertRaises(ValueError):
136130
self.experiment.update_tracking_metric(

0 commit comments

Comments
 (0)