Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reap rarely used/unused methods and properties #3333

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 17 additions & 106 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import defaultdict, OrderedDict
from collections.abc import Hashable, Iterable, Mapping
from datetime import datetime
from functools import partial, reduce
from functools import partial

from typing import Any, cast

Expand Down Expand Up @@ -337,18 +337,17 @@ def arms_by_signature_for_deduplication(self) -> dict[str, Arm]:
return arms_dict

@property
def sum_trial_sizes(self) -> int:
"""Sum of numbers of arms attached to each trial in this experiment."""
return reduce(lambda a, b: a + len(b.arms_by_name), self._trials.values(), 0)
def metrics(self) -> dict[str, Metric]:
"""The metrics attached to the experiment."""
optimization_config_metrics: dict[str, Metric] = {}
if self.optimization_config is not None:
optimization_config_metrics = self.optimization_config.metrics
return {**self._tracking_metrics, **optimization_config_metrics}

@property
def num_abandoned_arms(self) -> int:
"""How many arms attached to this experiment are abandoned."""
abandoned = set()
for trial in self.trials.values():
for x in trial.abandoned_arms:
abandoned.add(x)
return len(abandoned)
return len({aa for t in self.trials.values() for aa in t.abandoned_arms})

@property
def optimization_config(self) -> OptimizationConfig | None:
Expand Down Expand Up @@ -477,14 +476,6 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
del self._tracking_metrics[metric_name]
return self

@property
def metrics(self) -> dict[str, Metric]:
"""The metrics attached to the experiment."""
optimization_config_metrics: dict[str, Metric] = {}
if self.optimization_config is not None:
optimization_config_metrics = self.optimization_config.metrics
return {**self._tracking_metrics, **optimization_config_metrics}

def _metrics_by_class(
self, metrics: list[Metric] | None = None
) -> dict[type[Metric], list[Metric]]:
Expand All @@ -500,6 +491,7 @@ def _metrics_by_class(

def fetch_data_results(
self,
trial_indices: Iterable[int] | None = None,
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
Expand Down Expand Up @@ -528,43 +520,9 @@ def fetch_data_results(
"""

return self._lookup_or_fetch_trials_results(
trials=list(self.trials.values()),
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
**kwargs,
)

def fetch_trials_data_results(
self,
trial_indices: Iterable[int],
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
**kwargs: Any,
) -> dict[int, dict[str, MetricFetchResult]]:
"""Fetches data for specific trials on the experiment.

If a metric fetch fails, the Exception will be captured in the
MetricFetchResult along with a message.

NOTE: For metrics that are not available while trial is running, the data
may be retrieved from cache on the experiment. Data is cached on the experiment
via calls to `experiment.attach_data` and whether a given metric class is
available while trial is running is determined by the boolean returned from its
`is_available_while_running` class method.

Args:
trial_indices: Indices of trials, for which to fetch data.
metrics: If provided, fetch data for these metrics instead of the ones
defined on the experiment.
kwargs: keyword args to pass to underlying metrics' fetch data functions.

Returns:
A nested Dictionary from trial_index => metric_name => result
"""
return self._lookup_or_fetch_trials_results(
trials=self.get_trials_by_indices(trial_indices=trial_indices),
trials=self.get_trials_by_indices(trial_indices=trial_indices)
if trial_indices is not None
else list(self.trials.values()),
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
Expand All @@ -573,6 +531,7 @@ def fetch_trials_data_results(

def fetch_data(
self,
trial_indices: Iterable[int] | None = None,
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
Expand Down Expand Up @@ -600,63 +559,15 @@ def fetch_data(
Data for the experiment.
"""

results = self._lookup_or_fetch_trials_results(
trials=list(self.trials.values()),
results = self.fetch_data_results(
trial_indices=trial_indices,
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
**kwargs,
)

base_metric_cls = (
MapMetric if self.default_data_constructor == MapData else Metric
)

return base_metric_cls._unwrap_experiment_data_multi(
results=results,
)

def fetch_trials_data(
self,
trial_indices: Iterable[int],
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
**kwargs: Any,
) -> Data:
"""Fetches data for specific trials on the experiment.

NOTE: For metrics that are not available while trial is running, the data
may be retrieved from cache on the experiment. Data is cached on the experiment
via calls to `experiment.attach_data` and whether a given metric class is
available while trial is running is determined by the boolean returned from its
`is_available_while_running` class method.

NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and
lose rows) if Experiment.default_data_type is misconfigured!

Args:
trial_indices: Indices of trials, for which to fetch data.
metrics: If provided, fetch data for these metrics instead of the ones
defined on the experiment.
kwargs: Keyword args to pass to underlying metrics' fetch data functions.

Returns:
Data for the specific trials on the experiment.
"""

results = self._lookup_or_fetch_trials_results(
trials=self.get_trials_by_indices(trial_indices=trial_indices),
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
**kwargs,
)

base_metric_cls = (
MapMetric if self.default_data_constructor == MapData else Metric
)
return base_metric_cls._unwrap_experiment_data_multi(
use_map_data = self.default_data_constructor == MapData
return (MapMetric if use_map_data else Metric)._unwrap_experiment_data_multi(
results=results,
)

Expand Down
11 changes: 8 additions & 3 deletions ax/core/multi_type_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import logging
from collections.abc import Sequence
from typing import Any
from typing import Any, Iterable

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -254,6 +254,7 @@ def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment":
@copy_doc(Experiment.fetch_data)
def fetch_data(
self,
trial_indices: Iterable[int] | None = None,
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
Expand All @@ -265,11 +266,15 @@ def fetch_data(
return self.default_data_constructor.from_multiple_data(
[
(
trial.fetch_data(**kwargs, metrics=metrics)
trial.fetch_data(metrics=metrics, **kwargs)
if trial.status.expecting_data
else Data()
)
for trial in self.trials.values()
for trial in (
self.get_trials_by_indices(trial_indices=trial_indices)
if trial_indices is not None
else self.trials.values()
)
]
)

Expand Down
37 changes: 21 additions & 16 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from collections import OrderedDict
from enum import unique
from functools import reduce
from unittest.mock import MagicMock, patch

import pandas as pd
Expand Down Expand Up @@ -608,7 +609,9 @@ def test_NumArmsNoDeduplication(self) -> None:
arm = get_arm()
exp.new_batch_trial().add_arm(arm)
trial = exp.new_batch_trial().add_arm(arm)
self.assertEqual(exp.sum_trial_sizes, 2)
self.assertEqual(
reduce(lambda a, b: a + len(b.arms_by_name), exp._trials.values(), 0), 2
)
self.assertEqual(len(exp.arms_by_name), 1)
trial.mark_arm_abandoned(trial.arms[0].name)
self.assertEqual(exp.num_abandoned_arms, 1)
Expand Down Expand Up @@ -666,34 +669,34 @@ def test_FetchTrialsData(self) -> None:
batch_1 = exp.trials[1]
batch_0.mark_completed()
batch_1.mark_completed()
batch_0_data = exp.fetch_trials_data(trial_indices=[0])
batch_0_data = exp.fetch_data(trial_indices=[0])
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
)
batch_1_data = exp.fetch_trials_data(trial_indices=[1])
batch_1_data = exp.fetch_data(trial_indices=[1])
self.assertEqual(set(batch_1_data.df["trial_index"].values), {1})
self.assertEqual(
set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms}
)
self.assertEqual(
exp.fetch_trials_data(trial_indices=[0, 1]),
exp.fetch_data(trial_indices=[0, 1]),
Data.from_multiple_data([batch_0_data, batch_1_data]),
)

self.assertEqual(len(exp.data_by_trial[0]), 2)

with self.assertRaisesRegex(ValueError, ".* not associated .*"):
exp.fetch_trials_data(trial_indices=[2])
exp.fetch_data(trial_indices=[2])
# Try to fetch data when there are only metrics and no attached data.
exp.remove_tracking_metric(metric_name="b") # Remove implemented metric.
exp.add_tracking_metric(Metric(name="b")) # Add unimplemented metric.
self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).df), 5)
self.assertEqual(len(exp.fetch_data(trial_indices=[0]).df), 5)
# Try fetching attached data.
exp.attach_data(batch_0_data)
exp.attach_data(batch_1_data)
self.assertEqual(exp.fetch_trials_data(trial_indices=[0]), batch_0_data)
self.assertEqual(exp.fetch_trials_data(trial_indices=[1]), batch_1_data)
self.assertEqual(exp.fetch_data(trial_indices=[0]), batch_0_data)
self.assertEqual(exp.fetch_data(trial_indices=[1]), batch_1_data)
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
Expand Down Expand Up @@ -1444,38 +1447,40 @@ def test_FetchTrialsData(self) -> None:
batch_1 = exp.trials[1]
batch_0.mark_completed()
batch_1.mark_completed()
batch_0_data = exp.fetch_trials_data(trial_indices=[0])
batch_0_data = exp.fetch_data(trial_indices=[0])
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
)
batch_1_data = exp.fetch_trials_data(trial_indices=[1])
batch_1_data = exp.fetch_data(trial_indices=[1])
self.assertEqual(set(batch_1_data.df["trial_index"].values), {1})
self.assertEqual(
set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms}
)
self.assertEqual(
exp.fetch_trials_data(trial_indices=[0, 1]).df.shape[0],
exp.fetch_data(trial_indices=[0, 1]).df.shape[0],
len(exp.arms_by_name) * 2,
)

with self.assertRaisesRegex(ValueError, ".* not associated .*"):
exp.fetch_trials_data(trial_indices=[2])
exp.fetch_data(trial_indices=[2])
# Try to fetch data when there are only metrics and no attached data.
exp.remove_tracking_metric(metric_name="branin") # Remove implemented metric.
exp.add_tracking_metric(
BraninMetric(name="branin", param_names=["x1", "x2"])
) # Add unimplemented metric.
# pyre-fixme[16]: `Data` has no attribute `map_df`.
self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).map_df), 10)
self.assertEqual(
len(assert_is_instance(exp.fetch_data(trial_indices=[0]), MapData).map_df),
10,
)
# Try fetching attached data.
exp.attach_data(batch_0_data)
exp.attach_data(batch_1_data)
pd.testing.assert_frame_equal(
exp.fetch_trials_data(trial_indices=[0]).df, batch_0_data.df
exp.fetch_data(trial_indices=[0]).df, batch_0_data.df
)
pd.testing.assert_frame_equal(
exp.fetch_trials_data(trial_indices=[1]).df, batch_1_data.df
exp.fetch_data(trial_indices=[1]).df, batch_1_data.df
)
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
Expand Down
2 changes: 1 addition & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,7 +1991,7 @@ def _fetch_and_process_trials_data_results(
self.experiment, MultiTypeExperiment
).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
kwargs["metrics"] = metrics
results = self.experiment.fetch_trials_data_results(
results = self.experiment.fetch_data_results(
trial_indices=trial_indices,
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2599,7 +2599,7 @@ def helper_test_get_pareto_optimal_points_from_sobol_step(
# Check that the data in the frontier matches the observed data
# (it should be in the original, un-transformed space)
input_data = (
ax_client.experiment.fetch_trials_data([idx_of_frontier_point])
ax_client.experiment.fetch_data(trial_indices=[idx_of_frontier_point])
.df["mean"]
.values
)
Expand Down