
Reap rarely used/unused methods and properties
Summary: There were very few usages of the methods this diff removes, so it updates the remaining callsites and reaps the methods.

Differential Revision: D69313269
Lena Kashtelyan authored and facebook-github-bot committed Feb 7, 2025
1 parent 6b8287b commit e3f9161
Showing 5 changed files with 48 additions and 127 deletions.
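As a quick illustration of the callsite migration this commit performs (a hedged sketch, not part of the diff itself; `exp` stands in for an already-configured ax.core.experiment.Experiment):

# Before: dedicated methods for per-trial fetches.
# data = exp.fetch_trials_data(trial_indices=[0, 1])
# results = exp.fetch_trials_data_results(trial_indices=[0, 1])

# After: fetch_data / fetch_data_results subsume them via the new optional
# `trial_indices` argument; omitting it fetches data for all trials, as before.
data_all = exp.fetch_data()
data_some = exp.fetch_data(trial_indices=[0, 1])
results = exp.fetch_data_results(trial_indices=[0, 1])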
123 changes: 17 additions & 106 deletions ax/core/experiment.py
@@ -14,7 +14,7 @@
from collections import defaultdict, OrderedDict
from collections.abc import Hashable, Iterable, Mapping
from datetime import datetime
from functools import partial, reduce
from functools import partial

from typing import Any, cast

@@ -355,18 +355,17 @@ def arms_by_signature_for_deduplication(self) -> dict[str, Arm]:
return arms_dict

@property
def sum_trial_sizes(self) -> int:
"""Sum of numbers of arms attached to each trial in this experiment."""
return reduce(lambda a, b: a + len(b.arms_by_name), self._trials.values(), 0)
def metrics(self) -> dict[str, Metric]:
"""The metrics attached to the experiment."""
optimization_config_metrics: dict[str, Metric] = {}
if self.optimization_config is not None:
optimization_config_metrics = self.optimization_config.metrics
return {**self._tracking_metrics, **optimization_config_metrics}

@property
def num_abandoned_arms(self) -> int:
"""How many arms attached to this experiment are abandoned."""
abandoned = set()
for trial in self.trials.values():
for x in trial.abandoned_arms:
abandoned.add(x)
return len(abandoned)
return len({aa for t in self.trials.values() for aa in t.abandoned_arms})

@property
def optimization_config(self) -> OptimizationConfig | None:
@@ -495,14 +494,6 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
del self._tracking_metrics[metric_name]
return self

@property
def metrics(self) -> dict[str, Metric]:
"""The metrics attached to the experiment."""
optimization_config_metrics: dict[str, Metric] = {}
if self.optimization_config is not None:
optimization_config_metrics = self.optimization_config.metrics
return {**self._tracking_metrics, **optimization_config_metrics}

def _metrics_by_class(
self, metrics: list[Metric] | None = None
) -> dict[type[Metric], list[Metric]]:
@@ -518,6 +509,7 @@

def fetch_data_results(
self,
trial_indices: Iterable[int] | None = None,
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
@@ -546,43 +538,9 @@ def fetch_data_results(
"""

return self._lookup_or_fetch_trials_results(
trials=list(self.trials.values()),
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
**kwargs,
)

def fetch_trials_data_results(
self,
trial_indices: Iterable[int],
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
**kwargs: Any,
) -> dict[int, dict[str, MetricFetchResult]]:
"""Fetches data for specific trials on the experiment.
If a metric fetch fails, the Exception will be captured in the
MetricFetchResult along with a message.
NOTE: For metrics that are not available while trial is running, the data
may be retrieved from cache on the experiment. Data is cached on the experiment
via calls to `experiment.attach_data` and whether a given metric class is
available while trial is running is determined by the boolean returned from its
`is_available_while_running` class method.
Args:
trial_indices: Indices of trials, for which to fetch data.
metrics: If provided, fetch data for these metrics instead of the ones
defined on the experiment.
kwargs: keyword args to pass to underlying metrics' fetch data functions.
Returns:
A nested Dictionary from trial_index => metric_name => result
"""
return self._lookup_or_fetch_trials_results(
trials=self.get_trials_by_indices(trial_indices=trial_indices),
trials=self.get_trials_by_indices(trial_indices=trial_indices)
if trial_indices is not None
else list(self.trials.values()),
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
@@ -591,6 +549,7 @@

def fetch_data(
self,
trial_indices: Iterable[int] | None = None,
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
@@ -618,63 +577,15 @@
Data for the experiment.
"""

results = self._lookup_or_fetch_trials_results(
trials=list(self.trials.values()),
results = self.fetch_data_results(
trial_indices=trial_indices,
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
**kwargs,
)

base_metric_cls = (
MapMetric if self.default_data_constructor == MapData else Metric
)

return base_metric_cls._unwrap_experiment_data_multi(
results=results,
)

def fetch_trials_data(
self,
trial_indices: Iterable[int],
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
**kwargs: Any,
) -> Data:
"""Fetches data for specific trials on the experiment.
NOTE: For metrics that are not available while trial is running, the data
may be retrieved from cache on the experiment. Data is cached on the experiment
via calls to `experiment.attach_data` and whether a given metric class is
available while trial is running is determined by the boolean returned from its
`is_available_while_running` class method.
NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and
lose rows) if Experiment.default_data_type is misconfigured!
Args:
trial_indices: Indices of trials, for which to fetch data.
metrics: If provided, fetch data for these metrics instead of the ones
defined on the experiment.
kwargs: Keyword args to pass to underlying metrics' fetch data functions.
Returns:
Data for the specific trials on the experiment.
"""

results = self._lookup_or_fetch_trials_results(
trials=self.get_trials_by_indices(trial_indices=trial_indices),
metrics=metrics,
combine_with_last_data=combine_with_last_data,
overwrite_existing_data=overwrite_existing_data,
**kwargs,
)

base_metric_cls = (
MapMetric if self.default_data_constructor == MapData else Metric
)
return base_metric_cls._unwrap_experiment_data_multi(
use_map_data = self.default_data_constructor == MapData
return (MapMetric if use_map_data else Metric)._unwrap_experiment_data_multi(
results=results,
)

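Callers that relied on the reaped `sum_trial_sizes` property can inline an equivalent expression, as the updated test further down does. A minimal sketch (assuming `exp` is an Experiment; not part of the diff):

from functools import reduce

# Total number of named arms attached across all trials of `exp`,
# mirroring the reaped property's reduce-based implementation.
total_arms = reduce(
    lambda acc, trial: acc + len(trial.arms_by_name), exp.trials.values(), 0
)

# An equivalent, arguably plainer spelling:
total_arms = sum(len(trial.arms_by_name) for trial in exp.trials.values())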
11 changes: 8 additions & 3 deletions ax/core/multi_type_experiment.py
@@ -8,7 +8,7 @@

import logging
from collections.abc import Sequence
from typing import Any
from typing import Any, Iterable

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
@@ -256,6 +256,7 @@ def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment":
@copy_doc(Experiment.fetch_data)
def fetch_data(
self,
trial_indices: Iterable[int] | None = None,
metrics: list[Metric] | None = None,
combine_with_last_data: bool = False,
overwrite_existing_data: bool = False,
@@ -267,11 +268,15 @@ def fetch_data(
return self.default_data_constructor.from_multiple_data(
[
(
trial.fetch_data(**kwargs, metrics=metrics)
trial.fetch_data(metrics=metrics, **kwargs)
if trial.status.expecting_data
else Data()
)
for trial in self.trials.values()
for trial in (
self.get_trials_by_indices(trial_indices=trial_indices)
if trial_indices is not None
else self.trials.values()
)
]
)

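The multi-type variant gains the same optional filter; a brief usage sketch (hypothetical; `mt_exp` stands in for an existing MultiTypeExperiment):

all_data = mt_exp.fetch_data()                 # every trial, as before
subset = mt_exp.fetch_data(trial_indices=[0])  # only trial 0

# Trials whose status is not expecting data contribute an empty Data(),
# so the combined result only contains rows for fetchable trials.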
37 changes: 21 additions & 16 deletions ax/core/tests/test_experiment.py
@@ -9,6 +9,7 @@
import logging
from collections import OrderedDict
from enum import unique
from functools import reduce
from unittest.mock import MagicMock, patch

import pandas as pd
@@ -609,7 +610,9 @@ def test_NumArmsNoDeduplication(self) -> None:
arm = get_arm()
exp.new_batch_trial().add_arm(arm)
trial = exp.new_batch_trial().add_arm(arm)
self.assertEqual(exp.sum_trial_sizes, 2)
self.assertEqual(
reduce(lambda a, b: a + len(b.arms_by_name), exp._trials.values(), 0), 2
)
self.assertEqual(len(exp.arms_by_name), 1)
trial.mark_arm_abandoned(trial.arms[0].name)
self.assertEqual(exp.num_abandoned_arms, 1)
@@ -667,34 +670,34 @@ def test_FetchTrialsData(self) -> None:
batch_1 = exp.trials[1]
batch_0.mark_completed()
batch_1.mark_completed()
batch_0_data = exp.fetch_trials_data(trial_indices=[0])
batch_0_data = exp.fetch_data(trial_indices=[0])
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
)
batch_1_data = exp.fetch_trials_data(trial_indices=[1])
batch_1_data = exp.fetch_data(trial_indices=[1])
self.assertEqual(set(batch_1_data.df["trial_index"].values), {1})
self.assertEqual(
set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms}
)
self.assertEqual(
exp.fetch_trials_data(trial_indices=[0, 1]),
exp.fetch_data(trial_indices=[0, 1]),
Data.from_multiple_data([batch_0_data, batch_1_data]),
)

self.assertEqual(len(exp.data_by_trial[0]), 2)

with self.assertRaisesRegex(ValueError, ".* not associated .*"):
exp.fetch_trials_data(trial_indices=[2])
exp.fetch_data(trial_indices=[2])
# Try to fetch data when there are only metrics and no attached data.
exp.remove_tracking_metric(metric_name="b") # Remove implemented metric.
exp.add_tracking_metric(Metric(name="b")) # Add unimplemented metric.
self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).df), 5)
self.assertEqual(len(exp.fetch_data(trial_indices=[0]).df), 5)
# Try fetching attached data.
exp.attach_data(batch_0_data)
exp.attach_data(batch_1_data)
self.assertEqual(exp.fetch_trials_data(trial_indices=[0]), batch_0_data)
self.assertEqual(exp.fetch_trials_data(trial_indices=[1]), batch_1_data)
self.assertEqual(exp.fetch_data(trial_indices=[0]), batch_0_data)
self.assertEqual(exp.fetch_data(trial_indices=[1]), batch_1_data)
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
@@ -1445,38 +1448,40 @@ def test_FetchTrialsData(self) -> None:
batch_1 = exp.trials[1]
batch_0.mark_completed()
batch_1.mark_completed()
batch_0_data = exp.fetch_trials_data(trial_indices=[0])
batch_0_data = exp.fetch_data(trial_indices=[0])
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
)
batch_1_data = exp.fetch_trials_data(trial_indices=[1])
batch_1_data = exp.fetch_data(trial_indices=[1])
self.assertEqual(set(batch_1_data.df["trial_index"].values), {1})
self.assertEqual(
set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms}
)
self.assertEqual(
exp.fetch_trials_data(trial_indices=[0, 1]).df.shape[0],
exp.fetch_data(trial_indices=[0, 1]).df.shape[0],
len(exp.arms_by_name) * 2,
)

with self.assertRaisesRegex(ValueError, ".* not associated .*"):
exp.fetch_trials_data(trial_indices=[2])
exp.fetch_data(trial_indices=[2])
# Try to fetch data when there are only metrics and no attached data.
exp.remove_tracking_metric(metric_name="branin") # Remove implemented metric.
exp.add_tracking_metric(
BraninMetric(name="branin", param_names=["x1", "x2"])
) # Add unimplemented metric.
# pyre-fixme[16]: `Data` has no attribute `map_df`.
self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).map_df), 10)
self.assertEqual(
len(assert_is_instance(exp.fetch_data(trial_indices=[0]), MapData).map_df),
10,
)
# Try fetching attached data.
exp.attach_data(batch_0_data)
exp.attach_data(batch_1_data)
pd.testing.assert_frame_equal(
exp.fetch_trials_data(trial_indices=[0]).df, batch_0_data.df
exp.fetch_data(trial_indices=[0]).df, batch_0_data.df
)
pd.testing.assert_frame_equal(
exp.fetch_trials_data(trial_indices=[1]).df, batch_1_data.df
exp.fetch_data(trial_indices=[1]).df, batch_1_data.df
)
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
self.assertEqual(
2 changes: 1 addition & 1 deletion ax/service/scheduler.py
@@ -1992,7 +1992,7 @@ def _fetch_and_process_trials_data_results(
self.experiment, MultiTypeExperiment
).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
kwargs["metrics"] = metrics
results = self.experiment.fetch_trials_data_results(
results = self.experiment.fetch_data_results(
trial_indices=trial_indices,
**kwargs,
)
2 changes: 1 addition & 1 deletion ax/service/tests/test_ax_client.py
@@ -2599,7 +2599,7 @@ def helper_test_get_pareto_optimal_points_from_sobol_step(
# Check that the data in the frontier matches the observed data
# (it should be in the original, un-transformed space)
input_data = (
ax_client.experiment.fetch_trials_data([idx_of_frontier_point])
ax_client.experiment.fetch_data(trial_indices=[idx_of_frontier_point])
.df["mean"]
.values
)
