Skip to content

Commit e3f9161

Browse files
Lena Kashtelyan and facebook-github-bot
authored and committed
Reap rarely used/unused methods and properties
Summary: The methods and properties removed here had very few usages, so this diff updates those call sites and then removes ("reaps") the methods. Differential Revision: D69313269
1 parent 6b8287b commit e3f9161

File tree

5 files changed

+48
-127
lines changed

5 files changed

+48
-127
lines changed

ax/core/experiment.py

Lines changed: 17 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections import defaultdict, OrderedDict
1515
from collections.abc import Hashable, Iterable, Mapping
1616
from datetime import datetime
17-
from functools import partial, reduce
17+
from functools import partial
1818

1919
from typing import Any, cast
2020

@@ -355,18 +355,17 @@ def arms_by_signature_for_deduplication(self) -> dict[str, Arm]:
355355
return arms_dict
356356

357357
@property
358-
def sum_trial_sizes(self) -> int:
359-
"""Sum of numbers of arms attached to each trial in this experiment."""
360-
return reduce(lambda a, b: a + len(b.arms_by_name), self._trials.values(), 0)
358+
def metrics(self) -> dict[str, Metric]:
359+
"""The metrics attached to the experiment."""
360+
optimization_config_metrics: dict[str, Metric] = {}
361+
if self.optimization_config is not None:
362+
optimization_config_metrics = self.optimization_config.metrics
363+
return {**self._tracking_metrics, **optimization_config_metrics}
361364

362365
@property
363366
def num_abandoned_arms(self) -> int:
364367
"""How many arms attached to this experiment are abandoned."""
365-
abandoned = set()
366-
for trial in self.trials.values():
367-
for x in trial.abandoned_arms:
368-
abandoned.add(x)
369-
return len(abandoned)
368+
return len({aa for t in self.trials.values() for aa in t.abandoned_arms})
370369

371370
@property
372371
def optimization_config(self) -> OptimizationConfig | None:
@@ -495,14 +494,6 @@ def remove_tracking_metric(self, metric_name: str) -> Experiment:
495494
del self._tracking_metrics[metric_name]
496495
return self
497496

498-
@property
499-
def metrics(self) -> dict[str, Metric]:
500-
"""The metrics attached to the experiment."""
501-
optimization_config_metrics: dict[str, Metric] = {}
502-
if self.optimization_config is not None:
503-
optimization_config_metrics = self.optimization_config.metrics
504-
return {**self._tracking_metrics, **optimization_config_metrics}
505-
506497
def _metrics_by_class(
507498
self, metrics: list[Metric] | None = None
508499
) -> dict[type[Metric], list[Metric]]:
@@ -518,6 +509,7 @@ def _metrics_by_class(
518509

519510
def fetch_data_results(
520511
self,
512+
trial_indices: Iterable[int] | None = None,
521513
metrics: list[Metric] | None = None,
522514
combine_with_last_data: bool = False,
523515
overwrite_existing_data: bool = False,
@@ -546,43 +538,9 @@ def fetch_data_results(
546538
"""
547539

548540
return self._lookup_or_fetch_trials_results(
549-
trials=list(self.trials.values()),
550-
metrics=metrics,
551-
combine_with_last_data=combine_with_last_data,
552-
overwrite_existing_data=overwrite_existing_data,
553-
**kwargs,
554-
)
555-
556-
def fetch_trials_data_results(
557-
self,
558-
trial_indices: Iterable[int],
559-
metrics: list[Metric] | None = None,
560-
combine_with_last_data: bool = False,
561-
overwrite_existing_data: bool = False,
562-
**kwargs: Any,
563-
) -> dict[int, dict[str, MetricFetchResult]]:
564-
"""Fetches data for specific trials on the experiment.
565-
566-
If a metric fetch fails, the Exception will be captured in the
567-
MetricFetchResult along with a message.
568-
569-
NOTE: For metrics that are not available while trial is running, the data
570-
may be retrieved from cache on the experiment. Data is cached on the experiment
571-
via calls to `experiment.attach_data` and whether a given metric class is
572-
available while trial is running is determined by the boolean returned from its
573-
`is_available_while_running` class method.
574-
575-
Args:
576-
trial_indices: Indices of trials, for which to fetch data.
577-
metrics: If provided, fetch data for these metrics instead of the ones
578-
defined on the experiment.
579-
kwargs: keyword args to pass to underlying metrics' fetch data functions.
580-
581-
Returns:
582-
A nested Dictionary from trial_index => metric_name => result
583-
"""
584-
return self._lookup_or_fetch_trials_results(
585-
trials=self.get_trials_by_indices(trial_indices=trial_indices),
541+
trials=self.get_trials_by_indices(trial_indices=trial_indices)
542+
if trial_indices is not None
543+
else list(self.trials.values()),
586544
metrics=metrics,
587545
combine_with_last_data=combine_with_last_data,
588546
overwrite_existing_data=overwrite_existing_data,
@@ -591,6 +549,7 @@ def fetch_trials_data_results(
591549

592550
def fetch_data(
593551
self,
552+
trial_indices: Iterable[int] | None = None,
594553
metrics: list[Metric] | None = None,
595554
combine_with_last_data: bool = False,
596555
overwrite_existing_data: bool = False,
@@ -618,63 +577,15 @@ def fetch_data(
618577
Data for the experiment.
619578
"""
620579

621-
results = self._lookup_or_fetch_trials_results(
622-
trials=list(self.trials.values()),
580+
results = self.fetch_data_results(
581+
trial_indices=trial_indices,
623582
metrics=metrics,
624583
combine_with_last_data=combine_with_last_data,
625584
overwrite_existing_data=overwrite_existing_data,
626585
**kwargs,
627586
)
628-
629-
base_metric_cls = (
630-
MapMetric if self.default_data_constructor == MapData else Metric
631-
)
632-
633-
return base_metric_cls._unwrap_experiment_data_multi(
634-
results=results,
635-
)
636-
637-
def fetch_trials_data(
638-
self,
639-
trial_indices: Iterable[int],
640-
metrics: list[Metric] | None = None,
641-
combine_with_last_data: bool = False,
642-
overwrite_existing_data: bool = False,
643-
**kwargs: Any,
644-
) -> Data:
645-
"""Fetches data for specific trials on the experiment.
646-
647-
NOTE: For metrics that are not available while trial is running, the data
648-
may be retrieved from cache on the experiment. Data is cached on the experiment
649-
via calls to `experiment.attach_data` and whetner a given metric class is
650-
available while trial is running is determined by the boolean returned from its
651-
`is_available_while_running` class method.
652-
653-
NOTE: This can be lossy (ex. a MapData could get implicitly cast to a Data and
654-
lose rows) if Experiment.default_data_type is misconfigured!
655-
656-
Args:
657-
trial_indices: Indices of trials, for which to fetch data.
658-
metrics: If provided, fetch data for these metrics instead of the ones
659-
defined on the experiment.
660-
kwargs: Keyword args to pass to underlying metrics' fetch data functions.
661-
662-
Returns:
663-
Data for the specific trials on the experiment.
664-
"""
665-
666-
results = self._lookup_or_fetch_trials_results(
667-
trials=self.get_trials_by_indices(trial_indices=trial_indices),
668-
metrics=metrics,
669-
combine_with_last_data=combine_with_last_data,
670-
overwrite_existing_data=overwrite_existing_data,
671-
**kwargs,
672-
)
673-
674-
base_metric_cls = (
675-
MapMetric if self.default_data_constructor == MapData else Metric
676-
)
677-
return base_metric_cls._unwrap_experiment_data_multi(
587+
use_map_data = self.default_data_constructor == MapData
588+
return (MapMetric if use_map_data else Metric)._unwrap_experiment_data_multi(
678589
results=results,
679590
)
680591

ax/core/multi_type_experiment.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import logging
1010
from collections.abc import Sequence
11-
from typing import Any
11+
from typing import Any, Iterable
1212

1313
from ax.core.arm import Arm
1414
from ax.core.base_trial import BaseTrial, TrialStatus
@@ -256,6 +256,7 @@ def remove_tracking_metric(self, metric_name: str) -> "MultiTypeExperiment":
256256
@copy_doc(Experiment.fetch_data)
257257
def fetch_data(
258258
self,
259+
trial_indices: Iterable[int] | None = None,
259260
metrics: list[Metric] | None = None,
260261
combine_with_last_data: bool = False,
261262
overwrite_existing_data: bool = False,
@@ -267,11 +268,15 @@ def fetch_data(
267268
return self.default_data_constructor.from_multiple_data(
268269
[
269270
(
270-
trial.fetch_data(**kwargs, metrics=metrics)
271+
trial.fetch_data(metrics=metrics, **kwargs)
271272
if trial.status.expecting_data
272273
else Data()
273274
)
274-
for trial in self.trials.values()
275+
for trial in (
276+
self.get_trials_by_indices(trial_indices=trial_indices)
277+
if trial_indices is not None
278+
else self.trials.values()
279+
)
275280
]
276281
)
277282

ax/core/tests/test_experiment.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
from collections import OrderedDict
1111
from enum import unique
12+
from functools import reduce
1213
from unittest.mock import MagicMock, patch
1314

1415
import pandas as pd
@@ -609,7 +610,9 @@ def test_NumArmsNoDeduplication(self) -> None:
609610
arm = get_arm()
610611
exp.new_batch_trial().add_arm(arm)
611612
trial = exp.new_batch_trial().add_arm(arm)
612-
self.assertEqual(exp.sum_trial_sizes, 2)
613+
self.assertEqual(
614+
reduce(lambda a, b: a + len(b.arms_by_name), exp._trials.values(), 0), 2
615+
)
613616
self.assertEqual(len(exp.arms_by_name), 1)
614617
trial.mark_arm_abandoned(trial.arms[0].name)
615618
self.assertEqual(exp.num_abandoned_arms, 1)
@@ -667,34 +670,34 @@ def test_FetchTrialsData(self) -> None:
667670
batch_1 = exp.trials[1]
668671
batch_0.mark_completed()
669672
batch_1.mark_completed()
670-
batch_0_data = exp.fetch_trials_data(trial_indices=[0])
673+
batch_0_data = exp.fetch_data(trial_indices=[0])
671674
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
672675
self.assertEqual(
673676
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
674677
)
675-
batch_1_data = exp.fetch_trials_data(trial_indices=[1])
678+
batch_1_data = exp.fetch_data(trial_indices=[1])
676679
self.assertEqual(set(batch_1_data.df["trial_index"].values), {1})
677680
self.assertEqual(
678681
set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms}
679682
)
680683
self.assertEqual(
681-
exp.fetch_trials_data(trial_indices=[0, 1]),
684+
exp.fetch_data(trial_indices=[0, 1]),
682685
Data.from_multiple_data([batch_0_data, batch_1_data]),
683686
)
684687

685688
self.assertEqual(len(exp.data_by_trial[0]), 2)
686689

687690
with self.assertRaisesRegex(ValueError, ".* not associated .*"):
688-
exp.fetch_trials_data(trial_indices=[2])
691+
exp.fetch_data(trial_indices=[2])
689692
# Try to fetch data when there are only metrics and no attached data.
690693
exp.remove_tracking_metric(metric_name="b") # Remove implemented metric.
691694
exp.add_tracking_metric(Metric(name="b")) # Add unimplemented metric.
692-
self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).df), 5)
695+
self.assertEqual(len(exp.fetch_data(trial_indices=[0]).df), 5)
693696
# Try fetching attached data.
694697
exp.attach_data(batch_0_data)
695698
exp.attach_data(batch_1_data)
696-
self.assertEqual(exp.fetch_trials_data(trial_indices=[0]), batch_0_data)
697-
self.assertEqual(exp.fetch_trials_data(trial_indices=[1]), batch_1_data)
699+
self.assertEqual(exp.fetch_data(trial_indices=[0]), batch_0_data)
700+
self.assertEqual(exp.fetch_data(trial_indices=[1]), batch_1_data)
698701
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
699702
self.assertEqual(
700703
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
@@ -1445,38 +1448,40 @@ def test_FetchTrialsData(self) -> None:
14451448
batch_1 = exp.trials[1]
14461449
batch_0.mark_completed()
14471450
batch_1.mark_completed()
1448-
batch_0_data = exp.fetch_trials_data(trial_indices=[0])
1451+
batch_0_data = exp.fetch_data(trial_indices=[0])
14491452
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
14501453
self.assertEqual(
14511454
set(batch_0_data.df["arm_name"].values), {a.name for a in batch_0.arms}
14521455
)
1453-
batch_1_data = exp.fetch_trials_data(trial_indices=[1])
1456+
batch_1_data = exp.fetch_data(trial_indices=[1])
14541457
self.assertEqual(set(batch_1_data.df["trial_index"].values), {1})
14551458
self.assertEqual(
14561459
set(batch_1_data.df["arm_name"].values), {a.name for a in batch_1.arms}
14571460
)
14581461
self.assertEqual(
1459-
exp.fetch_trials_data(trial_indices=[0, 1]).df.shape[0],
1462+
exp.fetch_data(trial_indices=[0, 1]).df.shape[0],
14601463
len(exp.arms_by_name) * 2,
14611464
)
14621465

14631466
with self.assertRaisesRegex(ValueError, ".* not associated .*"):
1464-
exp.fetch_trials_data(trial_indices=[2])
1467+
exp.fetch_data(trial_indices=[2])
14651468
# Try to fetch data when there are only metrics and no attached data.
14661469
exp.remove_tracking_metric(metric_name="branin") # Remove implemented metric.
14671470
exp.add_tracking_metric(
14681471
BraninMetric(name="branin", param_names=["x1", "x2"])
14691472
) # Add unimplemented metric.
1470-
# pyre-fixme[16]: `Data` has no attribute `map_df`.
1471-
self.assertEqual(len(exp.fetch_trials_data(trial_indices=[0]).map_df), 10)
1473+
self.assertEqual(
1474+
len(assert_is_instance(exp.fetch_data(trial_indices=[0]), MapData).map_df),
1475+
10,
1476+
)
14721477
# Try fetching attached data.
14731478
exp.attach_data(batch_0_data)
14741479
exp.attach_data(batch_1_data)
14751480
pd.testing.assert_frame_equal(
1476-
exp.fetch_trials_data(trial_indices=[0]).df, batch_0_data.df
1481+
exp.fetch_data(trial_indices=[0]).df, batch_0_data.df
14771482
)
14781483
pd.testing.assert_frame_equal(
1479-
exp.fetch_trials_data(trial_indices=[1]).df, batch_1_data.df
1484+
exp.fetch_data(trial_indices=[1]).df, batch_1_data.df
14801485
)
14811486
self.assertEqual(set(batch_0_data.df["trial_index"].values), {0})
14821487
self.assertEqual(

ax/service/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1992,7 +1992,7 @@ def _fetch_and_process_trials_data_results(
19921992
self.experiment, MultiTypeExperiment
19931993
).metrics_for_trial_type(trial_type=none_throws(self.trial_type))
19941994
kwargs["metrics"] = metrics
1995-
results = self.experiment.fetch_trials_data_results(
1995+
results = self.experiment.fetch_data_results(
19961996
trial_indices=trial_indices,
19971997
**kwargs,
19981998
)

ax/service/tests/test_ax_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2599,7 +2599,7 @@ def helper_test_get_pareto_optimal_points_from_sobol_step(
25992599
# Check that the data in the frontier matches the observed data
26002600
# (it should be in the original, un-transformed space)
26012601
input_data = (
2602-
ax_client.experiment.fetch_trials_data([idx_of_frontier_point])
2602+
ax_client.experiment.fetch_data(trial_indices=[idx_of_frontier_point])
26032603
.df["mean"]
26042604
.values
26052605
)

0 commit comments

Comments
 (0)