Skip to content

Commit 7badae3

Browse files
authored
Use the baseline cache for the sampling evaluator (#550)
1 parent 761a8c8 commit 7badae3

2 files changed

Lines changed: 52 additions & 22 deletions

File tree

compiler_opt/es/blackbox_evaluator.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from compiler_opt.es import blackbox_optimizers
2929
from compiler_opt.distributed import buffered_scheduler
3030
from compiler_opt.rl import compilation_runner
31+
from compiler_opt import baseline_cache
3132

3233

3334
def _extract_results(futures: list[concurrent.futures.Future]) -> list[Any]:
@@ -51,6 +52,8 @@ def __init__(self, *, train_corpus: corpus.Corpus,
5152
estimator_type: blackbox_optimizers.EstimatorType):
5253
self._train_corpus = train_corpus
5354
self._estimator_type = estimator_type
55+
self._baseline_cache = baseline_cache.BaselineCache(
56+
get_key=lambda x: x.name)
5457

5558
@abc.abstractmethod
5659
def get_results(
@@ -73,12 +76,17 @@ def __init__(self,
7376
num_ir_repeats_within_worker: int = 1,
7477
**kwargs):
7578
super().__init__(**kwargs)
76-
self._samples: list[list[corpus.LoadedModuleSpec]] = []
7779
self._total_num_perturbations = total_num_perturbations
7880
self._num_ir_repeats_within_worker = num_ir_repeats_within_worker
79-
self._baselines: list[float | None] | None = None
81+
self._reset()
8082

81-
def _load_samples(self) -> None:
83+
def _reset(self):
84+
# TODO: this object currently depends on callers respecting an implicit state
85+
# transition, which makes it less maintainable than a stateless design.
86+
self._samples = None
87+
self._baselines = None
88+
89+
def load_samples(self) -> None:
8290
"""Samples and loads modules if not already done.
8391
8492
Ensures self._samples contains the expected number of loaded samples.
@@ -89,6 +97,7 @@ def _load_samples(self) -> None:
8997
"""
9098
if self._samples:
9199
raise RuntimeError('Samples have already been loaded.')
100+
self._samples = []
92101
for _ in range(self._total_num_perturbations):
93102
samples = self._train_corpus.sample(self._num_ir_repeats_within_worker)
94103
loaded_samples = [
@@ -108,15 +117,15 @@ def _load_samples(self) -> None:
108117
if len(self._samples) != expected_count:
109118
raise RuntimeError('Some samples could not be loaded correctly.')
110119

111-
def _launch_compilation_workers(self,
112-
pool: FixedWorkerPool,
113-
perturbations: list[bytes] | None = None
114-
) -> list[concurrent.futures.Future]:
115-
if self._samples is None:
116-
raise RuntimeError('Loaded samples are not available.')
120+
def _launch_compilation_workers(
121+
self,
122+
pool: FixedWorkerPool,
123+
samples: list[list[corpus.LoadedModuleSpec]],
124+
perturbations: list[bytes] | None = None
125+
) -> list[concurrent.futures.Future]:
117126
if perturbations is None:
118-
perturbations = [None] * len(self._samples)
119-
compile_args = zip(perturbations, self._samples)
127+
perturbations = [None] * len(samples)
128+
compile_args = zip(perturbations, samples)
120129
_, futures = buffered_scheduler.schedule_on_worker_pool(
121130
action=lambda w, args: w.compile(policy=args[0], modules=args[1]),
122131
jobs=compile_args,
@@ -130,24 +139,43 @@ def _launch_compilation_workers(self,
130139
not_done, return_when=concurrent.futures.FIRST_COMPLETED)
131140
return futures
132141

142+
def ensure_baselines(self, pool):
    """Computes one baseline score per group of loaded samples.

    Every loaded module is scored individually through the baseline cache;
    modules the cache asks to have evaluated are compiled on `pool` without a
    policy perturbation. The per-module scores are then summed back into one
    value per sample group and stored in `self._baselines`.

    Args:
      pool: worker pool used to compile the modules whose scores are requested
        by the cache.

    Raises:
      RuntimeError: if the samples have not been loaded yet.
    """
    if self._samples is None:
        raise RuntimeError('Loaded samples are not available.')
    # Flatten the samples so each module is looked up in the cache on its own.
    flat_samples = [item for sublist in self._samples for item in sublist]

    def _get_scores(some_list):
        # One single-module job per item, so results line up with `some_list`.
        futures = self._launch_compilation_workers(
            pool, [[x] for x in some_list])
        return _extract_results(futures)

    # NOTE(review): assumes get_score returns one score per entry of
    # flat_samples, in the same order - confirm against BaselineCache.
    baselines = self._baseline_cache.get_score(flat_samples, _get_scores)

    # TODO: the business of accumulating compilation results is now shared
    # with the worker.
    def sum_or_none(lst):
        return sum(lst) if all(x is not None for x in lst) else None

    # Regroup the flat scores. The window for each group must start where the
    # previous group ended; using the group index as the start offset is only
    # correct when every group holds exactly one module.
    grouped_baselines = []
    offset = 0
    for group in self._samples:
        end = offset + len(group)
        grouped_baselines.append(sum_or_none(baselines[offset:end]))
        offset = end
    self._baselines = grouped_baselines
163+
133164
def get_results(
134165
self, pool: FixedWorkerPool,
135166
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
136-
# We should have _samples by now.
137167
if not self._samples:
138-
raise RuntimeError('Loaded samples are not available.')
139-
return self._launch_compilation_workers(pool, perturbations)
168+
self.load_samples()
169+
self.ensure_baselines(pool)
170+
return self._launch_compilation_workers(pool, self._samples, perturbations)
140171

141172
def set_baseline(self, pool: FixedWorkerPool) -> None:
142-
if self._baselines is not None:
143-
raise RuntimeError('The baseline has already been set.')
144-
self._load_samples()
145-
results_futures = self._launch_compilation_workers(pool)
146-
self._baselines = _extract_results(results_futures)
173+
pass
147174

148175
def get_rewards(
149176
self,
150177
results_futures: list[concurrent.futures.Future]) -> list[float | None]:
178+
# we need a pool to get the baselines, so we should have gotten them already
151179
if self._baselines is None:
152180
raise RuntimeError('The baseline has not been set.')
153181

@@ -165,6 +193,7 @@ def get_rewards(
165193
else:
166194
rewards.append(
167195
compilation_runner.calculate_reward(policy_result, baseline))
196+
self._reset()
168197
return rewards
169198

170199

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def test_sampling_set_baseline(self):
6161
train_corpus=test_corpus,
6262
estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
6363
total_num_perturbations=1)
64-
65-
evaluator.set_baseline(pool)
64+
evaluator.load_samples()
65+
evaluator.ensure_baselines(pool)
6666
# pylint: disable=protected-access
6767
self.assertAlmostEqual(evaluator._baselines, [10])
6868

@@ -90,7 +90,8 @@ def test_sampling_get_rewards_with_baseline(self):
9090
estimator_type=blackbox_optimizers.EstimatorType.FORWARD_FD,
9191
total_num_perturbations=2)
9292

93-
evaluator.set_baseline(pool)
93+
evaluator.load_samples()
94+
evaluator.ensure_baselines(pool)
9495

9596
f_policy1 = concurrent.futures.Future()
9697
f_policy1.set_result(1.5)

0 commit comments

Comments
 (0)