Skip to content

Commit fcdbebf

Browse files
Support batched bb traces in TraceBlackboxEvaluator
This patch makes TraceBlackboxEvaluator support batched bb traces. Concretely, all of the bb traces obtained from a single memtrace can be split into a number of chunks, and each perturbation is then evaluated against only one randomly selected chunk rather than the full set. This adds some randomness to the training, which can be beneficial, and also massively speeds up evaluation. Reviewers: svkeerthy, mtrofin. Reviewed By: mtrofin. Pull Request: #510.
1 parent f4dc249 commit fcdbebf

2 files changed

Lines changed: 88 additions & 20 deletions

File tree

compiler_opt/es/blackbox_evaluator.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
import abc
1717
import concurrent.futures
18+
import os
19+
import random
1820

1921
from absl import logging
2022
import gin
23+
import tensorflow as tf
2124

2225
from compiler_opt.distributed.worker import FixedWorkerPool
2326
from compiler_opt.rl import corpus
@@ -114,20 +117,33 @@ def __init__(self, train_corpus: corpus.Corpus,
114117
bb_trace_path: str, function_index_path: str):
115118
self._train_corpus = train_corpus
116119
self._estimator_type = estimator_type
117-
self._bb_trace_path = bb_trace_path
120+
self._bb_trace_paths = []
121+
if tf.io.gfile.isdir(bb_trace_path):
122+
self._bb_trace_paths.extend([
123+
os.path.join(bb_trace_path, bb_trace)
124+
for bb_trace in tf.io.gfile.listdir(bb_trace_path)
125+
])
126+
else:
127+
self._bb_trace_paths.append(bb_trace_path)
118128
self._function_index_path = function_index_path
119129

120-
self._baseline: float | None = None
130+
self._baselines: list[float] | None = None
121131

122132
def get_results(
123133
self, pool: FixedWorkerPool,
124134
perturbations: list[bytes]) -> list[concurrent.futures.Future]:
125-
job_args = [{
126-
'modules': self._train_corpus.module_specs,
127-
'function_index_path': self._function_index_path,
128-
'bb_trace_path': self._bb_trace_path,
129-
'policy_as_bytes': perturbation,
130-
} for perturbation in perturbations]
135+
job_args = []
136+
self._current_baselines = []
137+
for perturbation in perturbations:
138+
bb_trace_path_index = random.randrange(len(self._bb_trace_paths))
139+
bb_trace_path = self._bb_trace_paths[bb_trace_path_index]
140+
self._current_baselines.append(self._baselines[bb_trace_path_index])
141+
job_args.append({
142+
'modules': self._train_corpus.module_specs,
143+
'function_index_path': self._function_index_path,
144+
'bb_trace_path': bb_trace_path,
145+
'policy_as_bytes': perturbation
146+
})
131147

132148
_, futures = buffered_scheduler.schedule_on_worker_pool(
133149
action=lambda w, args: w.compile_corpus_and_evaluate(**args),
@@ -138,15 +154,15 @@ def get_results(
138154
return futures
139155

140156
def set_baseline(self, pool: FixedWorkerPool) -> None:
141-
if self._baseline is not None:
157+
if self._baselines is not None:
142158
raise RuntimeError('The baseline has already been set.')
143159

144160
job_args = [{
145161
'modules': self._train_corpus.module_specs,
146162
'function_index_path': self._function_index_path,
147-
'bb_trace_path': self._bb_trace_path,
163+
'bb_trace_path': bb_trace_path,
148164
'policy_as_bytes': None,
149-
}]
165+
} for bb_trace_path in self._bb_trace_paths]
150166

151167
_, futures = buffered_scheduler.schedule_on_worker_pool(
152168
action=lambda w, args: w.compile_corpus_and_evaluate(**args),
@@ -155,21 +171,22 @@ def set_baseline(self, pool: FixedWorkerPool) -> None:
155171

156172
concurrent.futures.wait(
157173
futures, return_when=concurrent.futures.ALL_COMPLETED)
158-
if len(futures) != 1:
159-
raise ValueError('Expected to have one result for setting the baseline,'
160-
f' got {len(futures)}')
174+
if len(futures) != len(self._bb_trace_paths):
175+
raise ValueError(
176+
f'Expected to have {len(self._bb_trace_paths)} results for setting,'
177+
f'the baseline, got {len(futures)}.')
161178

162-
self._baseline = futures[0].result()
179+
self._baselines = [future.result() for future in futures]
163180

164181
def get_rewards(
165182
self, results: list[concurrent.futures.Future]) -> list[float | None]:
166183
rewards = []
167184

168-
for result in results:
185+
for result, baseline in zip(results, self._current_baselines):
169186
if result.exception() is not None:
170187
raise result.exception()
171188

172189
rewards.append(
173-
compilation_runner.calculate_reward(result.result(), self._baseline))
190+
compilation_runner.calculate_reward(result.result(), baseline))
174191

175192
return rewards

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def test_trace_get_results(self):
6767
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
6868
test_corpus, blackbox_optimizers.EstimatorType.FORWARD_FD,
6969
'fake_bb_trace_path', 'fake_function_index_path')
70+
# pylint: disable=protected-access
71+
evaluator._baselines = [1]
72+
# pylint: enable=protected-access
7073
results = evaluator.get_results(pool, perturbations)
7174
self.assertSequenceAlmostEqual([result.result() for result in results],
7275
[1.0, 1.0, 1.0])
@@ -85,7 +88,9 @@ def test_trace_set_baseline(self):
8588
'fake_bb_trace_path', 'fake_function_index_path')
8689
evaluator.set_baseline(pool)
8790
# pylint: disable=protected-access
88-
self.assertAlmostEqual(evaluator._baseline, 10)
91+
self.assertLen(evaluator._baselines, 1)
92+
self.assertAlmostEqual(evaluator._baselines[0], 10)
93+
# pylint: enable=protected-access
8994

9095
def test_trace_get_rewards(self):
9196
f1 = concurrent.futures.Future()
@@ -101,10 +106,56 @@ def test_trace_get_rewards(self):
101106
'fake_bb_trace_path', 'fake_function_index_path')
102107

103108
# pylint: disable=protected-access
104-
evaluator._baseline = 2
109+
evaluator._current_baselines = [2, 3]
110+
# pylint: enable=protected-access
105111
rewards = evaluator.get_rewards(results)
106112

107113
# Only check for two decimal places as the reward calculation uses a
108114
# reasonably large delta (0.01) when calculating the difference to
109115
# prevent division by zero.
110-
self.assertSequenceAlmostEqual(rewards, [0, -0.5], 2)
116+
self.assertSequenceAlmostEqual(rewards, [0, 0], 2)
117+
118+
def test_trace_multiple_get_results(self):
119+
with local_worker_manager.LocalWorkerPoolManager(
120+
blackbox_test_utils.ESTraceWorker,
121+
count=3,
122+
worker_args=(),
123+
worker_kwargs={}) as pool:
124+
perturbations = [b'00', b'01', b'10']
125+
test_corpus = corpus.create_corpus_for_testing(
126+
location=self.create_tempdir().full_path,
127+
elements=[corpus.ModuleSpec(name='name1', size=1)])
128+
bb_trace_dir = self.create_tempdir()
129+
bb_trace_dir.create_file('bb_trace1.pb')
130+
bb_trace_dir.create_file('bb_trace2.pb')
131+
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
132+
test_corpus, blackbox_optimizers.EstimatorType.FORWARD_FD,
133+
bb_trace_dir.full_path, 'fake_function_index_path')
134+
# pylint: disable=protected-access
135+
evaluator._baselines = [1, 2]
136+
# pylint: enable=protected-access
137+
results = evaluator.get_results(pool, perturbations)
138+
self.assertSequenceAlmostEqual([result.result() for result in results],
139+
[1.0, 1.0, 1.0])
140+
141+
def test_trace_multiple_set_baseline(self):
142+
with local_worker_manager.LocalWorkerPoolManager(
143+
blackbox_test_utils.ESTraceWorker,
144+
count=1,
145+
worker_args=(),
146+
worker_kwargs={}) as pool:
147+
test_corpus = corpus.create_corpus_for_testing(
148+
location=self.create_tempdir().full_path,
149+
elements=[corpus.ModuleSpec(name='name1', size=1)])
150+
bb_trace_dir = self.create_tempdir()
151+
bb_trace_dir.create_file('bb_trace1.pb')
152+
bb_trace_dir.create_file('bb_trace2.pb')
153+
evaluator = blackbox_evaluator.TraceBlackboxEvaluator(
154+
test_corpus, blackbox_optimizers.EstimatorType.FORWARD_FD,
155+
bb_trace_dir.full_path, 'fake_function_index_path')
156+
evaluator.set_baseline(pool)
157+
# pylint: disable=protected-access
158+
self.assertLen(evaluator._baselines, 2)
159+
self.assertAlmostEqual(evaluator._baselines[0], 10)
160+
self.assertAlmostEqual(evaluator._baselines[1], 10)
161+
# pylint: enable=protected-access

0 commit comments

Comments
 (0)