From e3f91616f6c62ea0cb3f2d1ddbc37311d8518620 Mon Sep 17 00:00:00 2001
From: Lena Kashtelyan
Date: Fri, 7 Feb 2025 15:11:37 -0800
Subject: [PATCH] Reap rarely used/unused methods and properties

Summary: The methods removed here had very few callsites, so this diff updates those callsites and reaps the methods: `fetch_trials_data` and `fetch_trials_data_results` are folded into `fetch_data` and `fetch_data_results` via a new optional `trial_indices` argument, and the unused `sum_trial_sizes` property is inlined at its only (test) callsite.

Differential Revision: D69313269
---
 ax/core/experiment.py              | 123 ++++-----------------------
 ax/core/multi_type_experiment.py   |  11 ++-
 ax/core/tests/test_experiment.py   |  37 +++++----
 ax/service/scheduler.py            |   2 +-
 ax/service/tests/test_ax_client.py |   2 +-
 5 files changed, 48 insertions(+), 127 deletions(-)

diff --git a/ax/core/experiment.py b/ax/core/experiment.py
index 64339ddfd7e..2f551271f46 100644
--- a/ax/core/experiment.py
+++ b/ax/core/experiment.py
@@ -14,7 +14,7 @@
 from collections import defaultdict, OrderedDict
 from collections.abc import Hashable, Iterable, Mapping
 from datetime import datetime
-from functools import partial, reduce
+from functools import partial
 from typing import Any, cast
 
@@ -355,18 +355,17 @@ def arms_by_signature_for_deduplication(self) -> dict[str, Arm]:
         return arms_dict
 
     @property
-    def sum_trial_sizes(self) -> int:
-        """Sum of numbers of arms attached to each trial in this experiment."""
-        return reduce(lambda a, b: a + len(b.arms_by_name), self._trials.values(), 0)
+    def metrics(self) -> dict[str, Metric]:
+        """The metrics attached to the experiment."""
+        optimization_config_metrics: dict[str, Metric] = {}
+        if self.optimization_config is not None:
+            optimization_config_metrics = self.optimization_config.metrics
+        return {**self._tracking_metrics, **optimization_config_metrics}
 
     @property
     def num_abandoned_arms(self) -> int:
         """How many arms attached to this experiment are abandoned."""
-        abandoned = set()
-        for trial in self.trials.values():
-            for x in trial.abandoned_arms:
-                abandoned.add(x)
-        return len(abandoned)
+        return len({aa for t in self.trials.values() for aa in t.abandoned_arms})
 
     @property
     def optimization_config(self) -> OptimizationConfig | None:
@@ -495,14 +494,6 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
         del self._tracking_metrics[metric_name]
         return self
 
-    @property
-    def metrics(self) -> dict[str, Metric]:
-        """The metrics attached to the experiment."""
-        optimization_config_metrics: dict[str, Metric] = {}
-        if self.optimization_config is not None:
-            optimization_config_metrics = self.optimization_config.metrics
-        return {**self._tracking_metrics, **optimization_config_metrics}
-
     def _metrics_by_class(
         self, metrics: list[Metric] | None = None
     ) -> dict[type[Metric], list[Metric]]:
@@ -518,6 +509,7 @@ def _metrics_by_class(
 
     def fetch_data_results(
         self,
+        trial_indices: Iterable[int] | None = None,
         metrics: list[Metric] | None = None,
         combine_with_last_data: bool = False,
         overwrite_existing_data: bool = False,
@@ -546,43 +538,9 @@
         """
 
         return self._lookup_or_fetch_trials_results(
-            trials=list(self.trials.values()),
-            metrics=metrics,
-            combine_with_last_data=combine_with_last_data,
-            overwrite_existing_data=overwrite_existing_data,
-            **kwargs,
-        )
-
-    def fetch_trials_data_results(
-        self,
-        trial_indices: Iterable[int],
-        metrics: list[Metric] | None = None,
-        combine_with_last_data: bool = False,
-        overwrite_existing_data: bool = False,
-        **kwargs: Any,
-    ) -> dict[int, dict[str, MetricFetchResult]]:
-        """Fetches data for specific trials on the experiment.
-
-        If a metric fetch fails, the Exception will be captured in the
-        MetricFetchResult along with a message.
-
-        NOTE: For metrics that are not available while trial is running, the data
-        may be retrieved from cache on the experiment. Data is cached on the experiment
-        via calls to `experiment.attach_data` and whether a given metric class is
-        available while trial is running is determined by the boolean returned from its
-        `is_available_while_running` class method.
-
-        Args:
-            trial_indices: Indices of trials, for which to fetch data.
-            metrics: If provided, fetch data for these metrics instead of the ones
-                defined on the experiment.
-            kwargs: keyword args to pass to underlying metrics' fetch data functions.
-
-        Returns:
-            A nested Dictionary from trial_index => metric_name => result
-        """
-        return self._lookup_or_fetch_trials_results(
-            trials=self.get_trials_by_indices(trial_indices=trial_indices),
+            trials=self.get_trials_by_indices(trial_indices=trial_indices)
+            if trial_indices is not None
+            else list(self.trials.values()),
             metrics=metrics,
             combine_with_last_data=combine_with_last_data,
             overwrite_existing_data=overwrite_existing_data,
@@ -591,6 +549,7 @@ def fetch_trials_data_results(
 
     def fetch_data(
         self,
+        trial_indices: Iterable[int] | None = None,
         metrics: list[Metric] | None = None,
         combine_with_last_data: bool = False,
         overwrite_existing_data: bool = False,
@@ -618,63 +577,15 @@
 
             Data for the experiment.
         """
 
-        results = self._lookup_or_fetch_trials_results(
-            trials=list(self.trials.values()),
+        results = self.fetch_data_results(
+            trial_indices=trial_indices,
             metrics=metrics,
             combine_with_last_data=combine_with_last_data,
             overwrite_existing_data=overwrite_existing_data,
             **kwargs,
         )
-
-        base_metric_cls = (
-            MapMetric if self.default_data_constructor == MapData else Metric
-        )
-
-        return base_metric_cls._unwrap_experiment_data_multi(
-            results=results,
-        )
-
-    def fetch_trials_data(
-        self,
-        trial_indices: Iterable[int],
-        metrics: list[Metric] | None = None,
-        combine_with_last_data: bool = False,
-        overwrite_existing_data: bool = False,
-        **kwargs: Any,
-    ) -> Data:
-        """Fetches data for specific trials on the experiment.
-
-        NOTE: For metrics that are not available while trial is running, the data
-        may be retrieved from cache on the experiment. Data is cached on the experiment
-        via calls to `experiment.attach_data` and whether a given metric class is
-        available while trial is running is determined by the boolean returned from its
-        `is_available_while_running` class method.
-
-        NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and
-        lose rows) if Experiment.default_data_type is misconfigured!
-
-        Args:
-            trial_indices: Indices of trials, for which to fetch data.
-            metrics: If provided, fetch data for these metrics instead of the ones
-                defined on the experiment.
-            kwargs: Keyword args to pass to underlying metrics' fetch data functions.
-
-        Returns:
-            Data for the specific trials on the experiment.
- """ - - results = self._lookup_or_fetch_trials_results( - trials=self.get_trials_by_indices(trial_indices=trial_indices), - metrics=metrics, - combine_with_last_data=combine_with_last_data, - overwrite_existing_data=overwrite_existing_data, - **kwargs, - ) - - base_metric_cls = ( - MapMetric if self.default_data_constructor == MapData else Metric - ) - return base_metric_cls._unwrap_experiment_data_multi( + use_map_data = self.default_data_constructor == MapData + return (MapMetric if use_map_data else Metric)._unwrap_experiment_data_multi( results=results, ) diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index 297ec626234..cbd744446bb 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -8,7 +8,7 @@ import logging from collections.abc import Sequence -from typing import Any +from typing import Any, Iterable from ax.core.arm import Arm from ax.core.base_trial import BaseTrial, TrialStatus @@ -256,6 +256,7 @@ def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment": @copy_doc(Experiment.fetch_data) def fetch_data( self, + trial_indices: Iterable[int] | None = None, metrics: list[Metric] | None = None, combine_with_last_data: bool = False, overwrite_existing_data: bool = False, @@ -267,11 +268,15 @@ def fetch_data( return self.default_data_constructor.from_multiple_data( [ ( - trial.fetch_data(**kwargs, metrics=metrics) + trial.fetch_data(metrics=metrics, **kwargs) if trial.status.expecting_data else Data() ) - for trial in self.trials.values() + for trial in ( + self.get_trials_by_indices(trial_indices=trial_indices) + if trial_indices is not None + else self.trials.values() + ) ] ) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 2d70bd5ad79..343669c1255 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -9,6 +9,7 @@ import logging from collections import OrderedDict from enum import unique +from functools import reduce from unittest.mock import MagicMock, patch import pandas as pd @@ -609,7 +610,9 @@ def test_NumArmsNoDeduplication(self) -> None: arm = get_arm() exp.new_batch_trial().add_arm(arm) trial = exp.new_batch_trial().add_arm(arm) - self.assertEqual(exp.sum_trial_sizes, 2) + self.assertEqual( + reduce(lambda a, b: a + len(b.arms_by_name), exp._trials.values(), 0), 2 + ) self.assertEqual(len(exp.arms_by_name), 1) trial.mark_arm_abandoned(trial.arms[0].name) self.assertEqual(exp.num_abandoned_arms, 1) @@ -667,34 +670,34 @@ def test_FetchTrialsData(self) -> None: batch_1 = exp.trials[1] batch_0.mark_completed() batch_1.mark_completed() - batch_0_data = exp.fetch_trials_data(trial_indices=[0]) + batch_0_data = exp.fetch_data(trial_indices=[0]) self.assertEqual(set(batch_0_data.df["trial_index"].values), {0}) self.assertEqual( set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms} ) - batch_1_data = exp.fetch_trials_data(trial_indices=[1]) + batch_1_data = exp.fetch_data(trial_indices=[1]) self.assertEqual(set(batch_1_data.df["trial_index"].values), {1}) self.assertEqual( set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms} ) self.assertEqual( - exp.fetch_trials_data(trial_indices=[0, 1]), + exp.fetch_data(trial_indices=[0, 1]), Data.from_multiple_data([batch_0_data, batch_1_data]), ) self.assertEqual(len(exp.data_by_trial[0]), 2) with self.assertRaisesRegex(ValueError, ".* not associated .*"): - exp.fetch_trials_data(trial_indices=[2]) + exp.fetch_data(trial_indices=[2]) # Try to fetch 
data when there are only metrics and no attached data. exp.remove_tracking_metric(metric_name="b") # Remove implemented metric. exp.add_tracking_metric(Metric(name="b")) # Add unimplemented metric. - self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).df), 5) + self.assertEqual(len(exp.fetch_data(trial_indices=[0]).df), 5) # Try fetching attached data. exp.attach_data(batch_0_data) exp.attach_data(batch_1_data) - self.assertEqual(exp.fetch_trials_data(trial_indices=[0]), batch_0_data) - self.assertEqual(exp.fetch_trials_data(trial_indices=[1]), batch_1_data) + self.assertEqual(exp.fetch_data(trial_indices=[0]), batch_0_data) + self.assertEqual(exp.fetch_data(trial_indices=[1]), batch_1_data) self.assertEqual(set(batch_0_data.df["trial_index"].values), {0}) self.assertEqual( set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms} @@ -1445,38 +1448,40 @@ def test_FetchTrialsData(self) -> None: batch_1 = exp.trials[1] batch_0.mark_completed() batch_1.mark_completed() - batch_0_data = exp.fetch_trials_data(trial_indices=[0]) + batch_0_data = exp.fetch_data(trial_indices=[0]) self.assertEqual(set(batch_0_data.df["trial_index"].values), {0}) self.assertEqual( set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms} ) - batch_1_data = exp.fetch_trials_data(trial_indices=[1]) + batch_1_data = exp.fetch_data(trial_indices=[1]) self.assertEqual(set(batch_1_data.df["trial_index"].values), {1}) self.assertEqual( set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms} ) self.assertEqual( - exp.fetch_trials_data(trial_indices=[0, 1]).df.shape[0], + exp.fetch_data(trial_indices=[0, 1]).df.shape[0], len(exp.arms_by_name) * 2, ) with self.assertRaisesRegex(ValueError, ".* not associated .*"): - exp.fetch_trials_data(trial_indices=[2]) + exp.fetch_data(trial_indices=[2]) # Try to fetch data when there are only metrics and no attached data. exp.remove_tracking_metric(metric_name="branin") # Remove implemented metric. exp.add_tracking_metric( BraninMetric(name="branin", param_names=["x1", "x2"]) ) # Add unimplemented metric. - # pyre-fixme[16]: `Data` has no attribute `map_df`. - self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).map_df), 10) + self.assertEqual( + len(assert_is_instance(exp.fetch_data(trial_indices=[0]), MapData).map_df), + 10, + ) # Try fetching attached data. 
exp.attach_data(batch_0_data) exp.attach_data(batch_1_data) pd.testing.assert_frame_equal( - exp.fetch_trials_data(trial_indices=[0]).df, batch_0_data.df + exp.fetch_data(trial_indices=[0]).df, batch_0_data.df ) pd.testing.assert_frame_equal( - exp.fetch_trials_data(trial_indices=[1]).df, batch_1_data.df + exp.fetch_data(trial_indices=[1]).df, batch_1_data.df ) self.assertEqual(set(batch_0_data.df["trial_index"].values), {0}) self.assertEqual( diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 7e81c06ca46..7f0ed2bb768 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -1992,7 +1992,7 @@ def _fetch_and_process_trials_data_results( self.experiment, MultiTypeExperiment ).metrics_for_trial_type(trial_type=none_throws(self.trial_type)) kwargs["metrics"] = metrics - results = self.experiment.fetch_trials_data_results( + results = self.experiment.fetch_data_results( trial_indices=trial_indices, **kwargs, ) diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 349367d1b33..6d4f9c70b31 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -2599,7 +2599,7 @@ def helper_test_get_pareto_optimal_points_from_sobol_step( # Check that the data in the frontier matches the observed data # (it should be in the original, un-transformed space) input_data = ( - ax_client.experiment.fetch_trials_data([idx_of_frontier_point]) + ax_client.experiment.fetch_data(trial_indices=[idx_of_frontier_point]) .df["mean"] .values )
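
Reviewer note (appended after the diff; not part of the patch): below is a minimal before/after sketch of the consolidated fetch API. The setup is illustrative only. It assumes Ax's `get_branin_experiment` test stub from `ax.utils.testing.core_stubs` (which ships with a synthetic runner); any experiment with trials would do.

    # Sketch only: get_branin_experiment is an Ax test stub, not added by this diff.
    from ax.utils.testing.core_stubs import get_branin_experiment

    exp = get_branin_experiment(with_batch=True)  # stub experiment with one batch trial
    exp.trials[0].run().mark_completed()          # deploy and complete the trial

    # Before this diff, per-trial fetches went through dedicated methods:
    #   exp.fetch_trials_data(trial_indices=[0])
    #   exp.fetch_trials_data_results(trial_indices=[0])

    # After this diff, trial filtering folds into the existing methods:
    subset_data = exp.fetch_data(trial_indices=[0])      # data for trial 0 only
    all_data = exp.fetch_data()                          # default unchanged: all trials
    results = exp.fetch_data_results(trial_indices=[0])  # trial_index -> metric -> result

Because `trial_indices` defaults to None, every existing `fetch_data()` / `fetch_data_results()` callsite keeps its behavior, while callers that previously needed the per-trial variants (e.g. the Scheduler hunk above) become one-line renames.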