Skip to content

Commit 227243c

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Auto-register extraneous metrics as tracking metrics in Client.attach_data (facebook#5135)
Summary: Pull Request resolved: facebook#5135 The `Client.attach_data` docstring claims unexpected metric values will be added as tracking metrics, but the code actually raises a `UserInputError`. This diff fixes the implementation to match the docstring: extraneous metrics in `raw_data` are now auto-registered via `configure_tracking_metrics` before the data is passed downstream. This also applies to `Client.complete_trial` since it delegates to `attach_data`. Reviewed By: Cesar-Cardoso Differential Revision: D99386924 fbshipit-source-id: b2be60b72b63d566828dbbd8a01648e38321c962
1 parent 74b42c3 commit 227243c

2 files changed

Lines changed: 12 additions & 13 deletions

File tree

ax/api/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,19 @@ def attach_data(
571571
progression: int | None = None,
572572
) -> None:
573573
"""
574-
Attach data without indicating the trial is complete. Missing metrics are,
574+
Attach data without indicating the trial is complete. Missing metrics are
575575
allowed, and unexpected metric values will be added to the Experiment as
576576
tracking metrics.
577577
578578
Saves to database on completion if ``storage_config`` is present.
579579
"""
580580

581+
# Auto-register any metrics present in raw_data but not yet on the
582+
# experiment as tracking metrics, matching the docstring contract.
583+
extra_metrics = set(raw_data.keys()) - set(self._experiment.metrics.keys())
584+
if extra_metrics:
585+
self.configure_tracking_metrics(metric_names=list(extra_metrics))
586+
581587
# If no progression is provided assume the data is not timeseries-like and
582588
# set step=NaN
583589
data_with_progression = [

ax/api/tests/test_client.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ax.core.trial_status import TrialStatus
4242
from ax.core.utils import compute_metric_availability, MetricAvailability
4343
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
44-
from ax.exceptions.core import UnsupportedError, UserInputError
44+
from ax.exceptions.core import UnsupportedError
4545
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
4646
from ax.storage.sqa_store.with_db_settings_base import (
4747
_save_generation_strategy_to_db_if_possible,
@@ -656,21 +656,14 @@ def test_attach_data(self) -> None:
656656
)
657657

658658
# With extra metrics
659-
# Try and attach data for a metric that doesn't exist
660-
with self.assertRaisesRegex(
661-
UserInputError,
662-
"Unable to find the metric signature for one or more metrics.",
663-
):
664-
client.attach_data(
665-
trial_index=trial_index,
666-
raw_data={"foo": 1.0, "bar": 2.0},
667-
)
668-
669-
client.configure_metrics(metrics=[DummyMetric(name="bar")])
659+
# Extraneous metrics should be auto-registered as tracking metrics
660+
self.assertNotIn("bar", client._experiment.metrics)
670661
client.attach_data(
671662
trial_index=trial_index,
672663
raw_data={"foo": 1.0, "bar": 2.0},
673664
)
665+
self.assertIn("bar", client._experiment.metrics)
666+
self.assertIn("bar", [m.name for m in client._experiment.tracking_metrics])
674667
self.assertEqual(
675668
client._experiment.trials[trial_index].status,
676669
TrialStatus.RUNNING,

0 commit comments

Comments
 (0)