Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions ax/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,33 @@ def configure_metrics(self, metrics: Sequence[IMetric]) -> None:
"""
self._set_metrics(metrics=metrics)

def configure_tracking_metrics(self, metric_names: Sequence[str]) -> None:
"""
Add tracking metrics to the ``Experiment`` by name.

Tracking metrics are metrics that are recorded during the experiment but
are not used as part of the ``OptimizationConfig`` (i.e., they are not
objectives or outcome constraints). Use this method to declare metrics
that you want to track alongside your optimization objectives.

If any of the metrics are already defined on the experiment, they will be
skipped with a warning.

Args:
metric_names: Names of metrics to be added as tracking metrics.

Saves to database on completion if ``storage_config`` is present.
"""
for metric_name in metric_names:
if metric_name in self._experiment.metrics:
logger.warning(
f"Metric {metric_name} already exists on experiment, skipping."
)
continue
self._experiment.add_tracking_metric(metric=Metric(name=metric_name))

self._save_experiment_to_db_if_possible(experiment=self._experiment)

# -------------------- Section 1.2: Set (not API) -------------------------------
def set_experiment(self, experiment: Experiment) -> None:
"""
Expand Down
29 changes: 29 additions & 0 deletions ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,35 @@ def test_configure_metric(self) -> None:
client._experiment.tracking_metrics[0],
)

def test_configure_tracking_metrics(self) -> None:
"""Test adding tracking metrics to an experiment."""
client = Client()

with self.assertRaisesRegex(AssertionError, "Experiment not set"):
client.configure_tracking_metrics(metric_names=["tracking1"])

client.configure_experiment(
parameters=[
RangeParameterConfig(name="x1", parameter_type="float", bounds=(0, 1))
],
name="test_tracking_metrics",
)
client.configure_optimization(objective="objective")

# Test adding tracking metrics
client.configure_tracking_metrics(metric_names=["tracking1", "tracking2"])

# Verify tracking metrics were added
self.assertIn("tracking1", client._experiment.metrics)
self.assertIn("tracking2", client._experiment.metrics)
self.assertEqual(len(client._experiment.tracking_metrics), 2)

# Test that adding an existing metric logs a warning and skips it
with mock.patch("ax.api.client.logger.warning") as mock_logger:
client.configure_tracking_metrics(metric_names=["tracking1"])
mock_logger.assert_called_once()
self.assertIn("already exists", mock_logger.call_args[0][0])

def test_set_experiment(self) -> None:
client = Client()
experiment = get_branin_experiment()
Expand Down