Skip to content

Commit 5efe7ce

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Add trial_type support to base Experiment metric methods
Summary: Phase 2 of moving MultiTypeExperiment features into base Experiment. Updates the base Experiment metric management methods (`add_metric`, `update_metric`, `remove_metric`) to accept an optional `trial_type` parameter. When provided, metrics are associated with the specified trial type in `_trial_type_to_metric_names`. The `__init__` and `optimization_config` setter also now register metrics when `default_trial_type` is set. The deprecated wrappers (`add_tracking_metric`, `add_tracking_metrics`, `update_tracking_metric`) now accept and pass through `trial_type` and `canonical_name` parameters. On MultiTypeExperiment, overrides are simplified to delegate to the base class methods: - `add_tracking_metric` delegates to `self.add_metric()` - `add_tracking_metrics` override removed (inherited from base) - `update_tracking_metric` delegates to `self.update_metric()` - `remove_tracking_metric` replaced with `remove_metric` override Differential Revision: D94986440
1 parent 78ef018 commit 5efe7ce

2 files changed

Lines changed: 95 additions & 104 deletions

File tree

ax/core/experiment.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ def __init__(
201201
# a naming collision occurs.
202202
for m in [*(tracking_metrics or []), *(metrics or [])]:
203203
self._metrics[m.name] = m
204+
if self._default_trial_type is not None:
205+
self._trial_type_to_metric_names.setdefault(
206+
self._default_trial_type, set()
207+
).add(m.name)
204208

205209
# call setters defined below
206210
self.status_quo = status_quo
@@ -582,6 +586,12 @@ def optimization_config(self, optimization_config: OptimizationConfig) -> None:
582586
"but not found on experiment. Add it first with add_metric()."
583587
)
584588
self._optimization_config = optimization_config
589+
default_trial_type = self._default_trial_type
590+
if default_trial_type is not None:
591+
for metric_name in optimization_config.metric_names:
592+
self._trial_type_to_metric_names.setdefault(
593+
default_trial_type, set()
594+
).add(metric_name)
585595

586596
@property
587597
def is_moo_problem(self) -> bool:
@@ -834,7 +844,7 @@ def get_metric(self, name: str) -> Metric:
834844
)
835845
return self._metrics[name]
836846

837-
def add_metric(self, metric: Metric) -> Self:
847+
def add_metric(self, metric: Metric, trial_type: str | None = None) -> Self:
838848
"""Add a new metric to the experiment.
839849
840850
Metrics that are not referenced by the experiment's optimization config
@@ -843,54 +853,110 @@ def add_metric(self, metric: Metric) -> Self:
843853
844854
Args:
845855
metric: Metric to be added.
856+
trial_type: If provided, associates the metric with this trial type.
857+
When ``None`` and a ``default_trial_type`` is set, defaults to
858+
the default trial type.
846859
"""
847860
if metric.name in self._metrics:
848861
raise ValueError(
849862
f"Metric `{metric.name}` already defined on experiment. "
850863
"Use `update_metric` to update an existing metric definition."
851864
)
865+
if trial_type is None and self._default_trial_type is not None:
866+
trial_type = self._default_trial_type
867+
if trial_type is not None:
868+
if not self.supports_trial_type(trial_type):
869+
raise ValueError(f"`{trial_type}` is not a supported trial type.")
870+
self._trial_type_to_metric_names.setdefault(trial_type, set()).add(
871+
metric.name
872+
)
852873
self._metrics[metric.name] = metric
853874
return self
854875

855-
def add_tracking_metric(self, metric: Metric) -> Self:
876+
def add_tracking_metric(
877+
self,
878+
metric: Metric,
879+
trial_type: str | None = None,
880+
canonical_name: str | None = None,
881+
) -> Self:
856882
"""*Deprecated.* Use ``add_metric`` instead."""
857883
warnings.warn(
858884
"add_tracking_metric is deprecated. Use add_metric instead.",
859885
DeprecationWarning,
860886
stacklevel=2,
861887
)
862-
return self.add_metric(metric)
888+
return self.add_metric(metric, trial_type=trial_type)
863889

864-
def add_tracking_metrics(self, metrics: list[Metric]) -> Experiment:
890+
def add_tracking_metrics(
891+
self,
892+
metrics: list[Metric],
893+
metrics_to_trial_types: dict[str, str] | None = None,
894+
canonical_names: dict[str, str] | None = None,
895+
) -> Experiment:
865896
"""*Deprecated.* Use ``add_metric`` instead."""
866897
warnings.warn(
867898
"add_tracking_metrics is deprecated. Use add_metric instead.",
868899
DeprecationWarning,
869900
stacklevel=2,
870901
)
902+
metrics_to_trial_types = metrics_to_trial_types or {}
871903
for metric in metrics:
872-
self.add_metric(metric)
904+
canonical_name = (canonical_names or {}).get(metric.name)
905+
self.add_tracking_metric(
906+
metric=metric,
907+
trial_type=metrics_to_trial_types.get(metric.name),
908+
canonical_name=canonical_name,
909+
)
873910
return self
874911

875-
def update_metric(self, metric: Metric) -> Self:
912+
def update_metric(self, metric: Metric, trial_type: str | None = None) -> Self:
876913
"""Redefine a metric that already exists on the experiment.
877914
878915
Args:
879916
metric: New metric definition.
917+
trial_type: If provided, reassociates the metric with this trial
918+
type. When ``None``, keeps the metric's existing trial type.
880919
"""
881920
if metric.name not in self._metrics:
882921
raise ValueError(f"Metric `{metric.name}` doesn't exist on experiment.")
922+
if trial_type is not None:
923+
if not self.supports_trial_type(trial_type):
924+
raise ValueError(f"`{trial_type}` is not a supported trial type.")
925+
oc = self._optimization_config
926+
if (
927+
oc is not None
928+
and metric.name in oc.metric_names
929+
and self._default_trial_type is not None
930+
and trial_type != self._default_trial_type
931+
):
932+
raise ValueError(
933+
f"Metric `{metric.name}` must remain a "
934+
f"`{self._default_trial_type}` metric because it is part of "
935+
"the optimization_config."
936+
)
937+
# Remove from any existing trial type set
938+
for names in self._trial_type_to_metric_names.values():
939+
names.discard(metric.name)
940+
# Add to new trial type set
941+
self._trial_type_to_metric_names.setdefault(trial_type, set()).add(
942+
metric.name
943+
)
883944
self._metrics[metric.name] = metric
884945
return self
885946

886-
def update_tracking_metric(self, metric: Metric) -> Experiment:
947+
def update_tracking_metric(
948+
self,
949+
metric: Metric,
950+
trial_type: str | None = None,
951+
canonical_name: str | None = None,
952+
) -> Experiment:
887953
"""*Deprecated.* Use ``update_metric`` instead."""
888954
warnings.warn(
889955
"update_tracking_metric is deprecated. Use update_metric instead.",
890956
DeprecationWarning,
891957
stacklevel=2,
892958
)
893-
return self.update_metric(metric)
959+
return self.update_metric(metric, trial_type=trial_type)
894960

895961
def remove_metric(self, metric_name: str) -> Self:
896962
"""Remove a metric from the experiment.
@@ -911,6 +977,9 @@ def remove_metric(self, metric_name: str) -> Self:
911977
f"Metric `{metric_name}` is referenced by the optimization config "
912978
"and cannot be removed. Update the optimization config first."
913979
)
980+
# Clean up _trial_type_to_metric_names
981+
for names in self._trial_type_to_metric_names.values():
982+
names.discard(metric_name)
914983
del self._metrics[metric_name]
915984
return self
916985

ax/core/multi_type_experiment.py

Lines changed: 18 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,13 @@ def __init__(
9696
default_data_type=default_data_type,
9797
)
9898

99-
# Ensure tracking metrics are registered in _metric_to_trial_type
100-
# and _trial_type_to_metric_names.
101-
# super().__init__ sets self._metrics directly, bypassing
102-
# add_tracking_metric, so tracking metrics won't be in
103-
# _metric_to_trial_type yet.
99+
# Ensure tracking metrics are registered in _metric_to_trial_type.
100+
# The base __init__ handles _trial_type_to_metric_names.
104101
for m in tracking_metrics or []:
105102
if m.name not in self._metric_to_trial_type:
106-
tt = none_throws(self._default_trial_type)
107-
self._metric_to_trial_type[m.name] = tt
108-
self._trial_type_to_metric_names.setdefault(tt, set()).add(m.name)
103+
self._metric_to_trial_type[m.name] = none_throws(
104+
self._default_trial_type
105+
)
109106

110107
def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
111108
"""Add a new trial_type to be supported by this experiment.
@@ -127,12 +124,11 @@ def add_trial_type(self, trial_type: str, runner: Runner) -> Self:
127124
def optimization_config(self, optimization_config: OptimizationConfig) -> None:
128125
# pyre-fixme[16]: `Optional` has no attribute `fset`.
129126
Experiment.optimization_config.fset(self, optimization_config)
127+
# Base setter handles _trial_type_to_metric_names; update legacy dict.
130128
for metric_name in optimization_config.metric_names:
131-
# Optimization config metrics are required to be the default trial type
132-
# currently. TODO: remove that restriction (T202797235)
133-
tt = none_throws(self.default_trial_type)
134-
self._metric_to_trial_type[metric_name] = tt
135-
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric_name)
129+
self._metric_to_trial_type[metric_name] = none_throws(
130+
self.default_trial_type
131+
)
136132

137133
def update_runner(self, trial_type: str, runner: Runner) -> Self:
138134
"""Update the default runner for an existing trial_type.
@@ -163,56 +159,12 @@ def add_tracking_metric(
163159
"""
164160
if trial_type is None:
165161
trial_type = self._default_trial_type
166-
if not self.supports_trial_type(trial_type):
167-
raise ValueError(f"`{trial_type}` is not a supported trial type.")
168-
169-
super().add_tracking_metric(metric)
170-
tt = none_throws(trial_type)
171-
self._metric_to_trial_type[metric.name] = tt
172-
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
162+
self.add_metric(metric, trial_type=trial_type)
163+
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
173164
if canonical_name is not None:
174165
self._metric_to_canonical_name[metric.name] = canonical_name
175166
return self
176167

177-
def add_tracking_metrics(
178-
self,
179-
metrics: list[Metric],
180-
metrics_to_trial_types: dict[str, str] | None = None,
181-
canonical_names: dict[str, str] | None = None,
182-
) -> Experiment:
183-
"""Add a list of new metrics to the experiment.
184-
185-
If any of the metrics are already defined on the experiment,
186-
we raise an error and don't add any of them to the experiment
187-
188-
Args:
189-
metrics: Metrics to be added.
190-
metrics_to_trial_types: The mapping from metric names to corresponding
191-
trial types for each metric. If provided, the metrics will be
192-
added to their trial types. If not provided, then the default
193-
trial type will be used.
194-
canonical_names: A mapping of metric names to their
195-
canonical names(The default metrics for which the metrics are
196-
proxies.)
197-
198-
Returns:
199-
The experiment with the added metrics.
200-
"""
201-
metrics_to_trial_types = metrics_to_trial_types or {}
202-
canonical_name = None
203-
for metric in metrics:
204-
if canonical_names is not None:
205-
canonical_name = none_throws(canonical_names).get(metric.name, None)
206-
207-
self.add_tracking_metric(
208-
metric=metric,
209-
trial_type=metrics_to_trial_types.get(
210-
metric.name, self._default_trial_type
211-
),
212-
canonical_name=canonical_name,
213-
)
214-
return self
215-
216168
def update_tracking_metric(
217169
self,
218170
metric: Metric,
@@ -233,47 +185,17 @@ def update_tracking_metric(
233185
trial_type = self._metric_to_trial_type.get(
234186
metric.name, self._default_trial_type
235187
)
236-
oc = self.optimization_config
237-
oc_metric_names = oc.metric_names if oc else set()
238-
if metric.name in oc_metric_names and trial_type != self._default_trial_type:
239-
raise ValueError(
240-
f"Metric `{metric.name}` must remain a "
241-
f"`{self._default_trial_type}` metric because it is part of the "
242-
"optimization_config."
243-
)
244-
elif not self.supports_trial_type(trial_type):
245-
raise ValueError(f"`{trial_type}` is not a supported trial type.")
246-
247-
super().update_tracking_metric(metric)
248-
# Remove from old trial type set
249-
old_tt = self._metric_to_trial_type.get(metric.name)
250-
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
251-
self._trial_type_to_metric_names[old_tt].discard(metric.name)
252-
# Add to new trial type set
253-
tt = none_throws(trial_type)
254-
self._metric_to_trial_type[metric.name] = tt
255-
self._trial_type_to_metric_names.setdefault(tt, set()).add(metric.name)
188+
self.update_metric(metric, trial_type=trial_type)
189+
self._metric_to_trial_type[metric.name] = none_throws(trial_type)
256190
if canonical_name is not None:
257191
self._metric_to_canonical_name[metric.name] = canonical_name
258192
return self
259193

260-
@copy_doc(Experiment.remove_tracking_metric)
261-
def remove_tracking_metric(self, metric_name: str) -> Self:
262-
if metric_name not in self._metrics:
263-
raise ValueError(f"Metric `{metric_name}` doesn't exist on experiment.")
264-
265-
# Clean up _trial_type_to_metric_names
266-
old_tt = self._metric_to_trial_type.get(metric_name)
267-
if old_tt is not None and old_tt in self._trial_type_to_metric_names:
268-
self._trial_type_to_metric_names[old_tt].discard(metric_name)
269-
270-
# Required fields
271-
del self._metrics[metric_name]
272-
del self._metric_to_trial_type[metric_name]
273-
274-
# Optional
275-
if metric_name in self._metric_to_canonical_name:
276-
del self._metric_to_canonical_name[metric_name]
194+
@copy_doc(Experiment.remove_metric)
195+
def remove_metric(self, metric_name: str) -> Self:
196+
super().remove_metric(metric_name)
197+
self._metric_to_trial_type.pop(metric_name, None)
198+
self._metric_to_canonical_name.pop(metric_name, None)
277199
return self
278200

279201
@copy_doc(Experiment.fetch_data)

0 commit comments

Comments
 (0)