Skip to content

Commit 7a7de6a

Browse files
liyuying0000 and copybara-github
authored and committed
Move benchmark_filter/workload_filter functions to benchmark.py
PiperOrigin-RevId: 728366933 Change-Id: I12e37ce3c2540bc6c1a4ea01b371dbd5250ad2f5
1 parent 4878d65 commit 7a7de6a

File tree

4 files changed

+206
-205
lines changed

4 files changed

+206
-205
lines changed

fleetbench/parallel/benchmark.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ def GetSubBenchmarks(benchmark_path: str, workload: str = "") -> list[str]:
6060
6161
If 'workload' is specified, only sub-benchmarks with the given workload are
6262
returned.
63+
64+
Args:
65+
benchmark_path: Path to the benchmark binary.
66+
workload: The workload to filter for. If empty, all sub-benchmarks are
67+
returned.
68+
69+
Returns:
70+
A list of sub-benchmark names that match the workload filter.
6371
"""
6472
cmd = [benchmark_path, "--benchmark_list_tests"]
6573

@@ -84,11 +92,11 @@ def GetWorkloads(benchmark_path: str):
8492
benchmarks = GetSubBenchmarks(benchmark_path, "all")
8593
workload_pattern = r"BM_(?P<workload>[^_]+)"
8694

87-
def extract_workload(benchmark):
95+
def _ExtractWorkload(benchmark):
8896
match = re.search(workload_pattern, benchmark)
8997
return match.group("workload") if match else None
9098

91-
return list(set(filter(None, map(extract_workload, benchmarks))))
99+
return list(set(filter(None, map(_ExtractWorkload, benchmarks))))
92100

93101

94102
class Benchmark:
@@ -116,3 +124,101 @@ def Path(self):
116124

117125
def __str__(self):
118126
return self.Name()
127+
128+
129+
def _CreateBenchmarks(bm_target: str, names: list[str]) -> dict[str, Benchmark]:
  """Builds a dictionary of Benchmark objects keyed by benchmark name.

  Args:
    bm_target: Path to the benchmark binary.
    names: Sub-benchmark names to wrap in Benchmark objects.

  Returns:
    A dict mapping each Benchmark's Name() to its Benchmark instance.
  """
  instances = (Benchmark(bm_target, name) for name in names)
  return {instance.Name(): instance for instance in instances}
136+
137+
138+
def _CreateMatchingBenchmarks(
    bm_target: str, bm_filter: str, bm_candidates: list[str]
) -> dict[str, Benchmark]:
  """Builds benchmarks for every candidate whose name contains the filter.

  Args:
    bm_target: Path to the benchmark binary.
    bm_filter: Substring a candidate name must contain to be selected.
    bm_candidates: Sub-benchmark names to filter.

  Returns:
    A dict of matching benchmark names to Benchmark objects.

  Raises:
    ValueError: If no candidate name contains the filter.
  """
  selected = [candidate for candidate in bm_candidates if bm_filter in candidate]
  if not selected:
    raise ValueError(f"Can't find benchmarks matching {bm_filter}.")
  return _CreateBenchmarks(bm_target, selected)
146+
147+
148+
def GetDefaultBenchmarks(
    benchmark_target: str, benchmark_filters: list[str]
) -> dict[str, Benchmark]:
  """Selects benchmarks from the default target, optionally keyword-filtered.

  Filtering options:
  - Empty list: Returns all default benchmarks.
  - Keyword list: Returns benchmarks from the default list matching the provided
    keyword (e.g., "Cold Hot").

  Args:
    benchmark_target: Path to the benchmark binary.
    benchmark_filters: List of filters to apply to the benchmarks to run.

  Returns:
    A map of benchmark names to Benchmark objects.
  """
  all_sub_benchmarks = GetSubBenchmarks(benchmark_target)

  # With no filters, every default sub-benchmark is selected.
  if not benchmark_filters:
    return _CreateBenchmarks(benchmark_target, all_sub_benchmarks)

  # Otherwise, union the matches of each keyword filter.
  selected: dict[str, Benchmark] = {}
  for keyword in benchmark_filters:
    selected.update(
        _CreateMatchingBenchmarks(benchmark_target, keyword, all_sub_benchmarks)
    )
  return selected
178+
179+
180+
def GetWorkloadBenchmarks(
    benchmark_target: str, workload_filters: list[str]
) -> dict[str, Benchmark]:
  """Collects benchmarks of the requested workloads, filtered by keywords.

  Each filter string is a workload name followed by comma-separated keywords:
  - Workload name + keyword(s): Returns benchmarks associated with the
    specified workload, further filtered by keywords (e.g.,
    "libc,Memcpy,Memcmp").
  - Workload name + "all": Returns all benchmarks associated with the
    specified workload (e.g., "proto,all").

  Args:
    benchmark_target: Path to the benchmark binary.
    workload_filters: List of filters to apply to the benchmarks to run.

  Returns:
    A map of benchmark names to Benchmark objects.

  Raises:
    ValueError: If a filter names a workload the target does not provide.
  """
  # All unique workloads reported by the target; membership is checked
  # against the upper-cased workload name from each filter.
  known_workloads = GetWorkloads(benchmark_target)

  def _SplitFilter(raw: str) -> tuple[str, list[str]]:
    workload, *keywords = raw.split(",")
    if workload.upper() not in known_workloads:
      raise ValueError(f"Workload {workload} not supported in Fleetbench.")
    return workload, keywords

  result: dict[str, Benchmark] = {}
  for raw_filter in workload_filters:
    workload, keywords = _SplitFilter(raw_filter)
    candidates = GetSubBenchmarks(benchmark_target, workload)
    if keywords == ["all"]:
      # "all" keeps every sub-benchmark of this workload (matched by the
      # upper-cased workload name embedded in the benchmark names).
      result.update(
          _CreateMatchingBenchmarks(
              benchmark_target, workload.upper(), candidates
          )
      )
    else:
      for keyword in keywords:
        result.update(
            _CreateMatchingBenchmarks(benchmark_target, keyword, candidates)
        )
  return result

fleetbench/parallel/benchmark_test.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,97 @@ def testGetSubBenchmarksWorkloadWithUnmatchedBM(self, mock_run):
147147
env=mock.ANY,
148148
)
149149

150+
@mock.patch.object(benchmark, "GetSubBenchmarks", autospec=True)
151+
@flagsaver.flagsaver(
152+
benchmark_dir=absltest.get_default_test_tmpdir(),
153+
)
154+
def test_getbenchmark_without_filter(self, mock_get_subbenchmarks):
155+
mock_get_subbenchmarks.return_value = ["BM_Test1", "BM_Test2"]
156+
self.create_tempfile(
157+
os.path.join(absltest.get_default_test_tmpdir(), "fake_bench")
158+
)
159+
160+
benchmarks = benchmark.GetDefaultBenchmarks("fake_bench", [])
161+
self.assertLen(benchmarks, 2)
162+
self.assertCountEqual(
163+
benchmarks.keys(),
164+
["fake_bench (BM_Test1)", "fake_bench (BM_Test2)"],
165+
)
166+
167+
@mock.patch.object(benchmark, "GetSubBenchmarks", autospec=True)
168+
@flagsaver.flagsaver(
169+
benchmark_dir=absltest.get_default_test_tmpdir(),
170+
)
171+
def test_getbenchmark_with_filter_partial_match(self, mock_get_subbenchmarks):
172+
mock_get_subbenchmarks.return_value = ["BM_Test1", "BM_Test2"]
173+
self.create_tempfile(
174+
os.path.join(absltest.get_default_test_tmpdir(), "fake_bench")
175+
)
176+
benchmarks = benchmark.GetDefaultBenchmarks("fake_bench", ["Test1"])
177+
self.assertLen(benchmarks, 1)
178+
self.assertCountEqual(
179+
benchmarks.keys(),
180+
["fake_bench (BM_Test1)"],
181+
)
182+
183+
@mock.patch.object(benchmark, "GetWorkloads", autospec=True)
184+
@mock.patch.object(benchmark, "GetSubBenchmarks", autospec=True)
185+
@flagsaver.flagsaver(
186+
benchmark_dir=absltest.get_default_test_tmpdir(),
187+
)
188+
def test_getworkloadbenchmark_subset(
189+
self, mock_get_subbenchmarks, mock_get_workloads
190+
):
191+
mock_get_workloads.return_value = ["PROTO", "CORD"]
192+
mock_get_subbenchmarks.side_effect = [
193+
["BM_PROTO_Test1", "BM_PROTO_Test2", "BM_PROTO_Test3"],
194+
["BM_CORD_Test1"],
195+
]
196+
self.create_tempfile(
197+
os.path.join(absltest.get_default_test_tmpdir(), "fake_bench")
198+
)
199+
200+
benchmarks = benchmark.GetWorkloadBenchmarks(
201+
"fake_bench", ["proto,1,2", "cord,all"]
202+
)
203+
self.assertLen(benchmarks, 3)
204+
self.assertCountEqual(
205+
benchmarks.keys(),
206+
[
207+
"fake_bench (BM_PROTO_Test1)",
208+
"fake_bench (BM_PROTO_Test2)",
209+
"fake_bench (BM_CORD_Test1)",
210+
],
211+
)
212+
213+
@mock.patch.object(benchmark, "GetWorkloads", autospec=True)
214+
@mock.patch.object(benchmark, "GetSubBenchmarks", autospec=True)
215+
@flagsaver.flagsaver(
216+
benchmark_dir=absltest.get_default_test_tmpdir(),
217+
)
218+
def test_getworkloadbenchmark_all(
219+
self, mock_get_subbenchmarks, mock_get_workloads
220+
):
221+
mock_get_workloads.return_value = ["PROTO"]
222+
mock_get_subbenchmarks.return_value = [
223+
"BM_PROTO_Test1",
224+
"BM_PROTO_Test2",
225+
"BM_PROTO_Test3",
226+
]
227+
self.create_tempfile(
228+
os.path.join(absltest.get_default_test_tmpdir(), "fake_bench")
229+
)
230+
benchmarks = benchmark.GetWorkloadBenchmarks("fake_bench", ["proto,all"])
231+
self.assertLen(benchmarks, 3)
232+
self.assertCountEqual(
233+
benchmarks.keys(),
234+
[
235+
"fake_bench (BM_PROTO_Test1)",
236+
"fake_bench (BM_PROTO_Test2)",
237+
"fake_bench (BM_PROTO_Test3)",
238+
],
239+
)
240+
150241

151242
if __name__ == "__main__":
152243
absltest.main()

fleetbench/parallel/parallel_bench_lib.py

Lines changed: 7 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
from fleetbench.parallel import benchmark as bm
2828
from fleetbench.parallel import cpu
29-
from fleetbench.parallel import result
3029
from fleetbench.parallel import run
3130
from fleetbench.parallel import worker
3231

@@ -41,7 +40,7 @@ def ParseBenchmarkWeights(
4140
filter should be in ALL CAPS to ensure case-insensitive matching.
4241
4342
Args:
44-
benchmark_list: A list of strings to parse.
43+
benchmark_weights_list: A list of strings to parse.
4544
4645
Returns:
4746
A dictionary of {capitalized string: float} or None if the list is empty.
@@ -56,114 +55,14 @@ def ParseBenchmarkWeights(
5655
benchmark_weights[key.upper()] = float(value_str)
5756
except ValueError:
5857
logging.warning(
59-
f"Invalid benchmark string: %s. The format should be"
60-
f" <benchmark_name|benchmark_filter>:<weight>. Skipping...",
58+
"Invalid benchmark string: %s. The format should be"
59+
" <benchmark_name|benchmark_filter>:<weight>. Skipping...",
6160
weights,
6261
)
6362

6463
return benchmark_weights
6564

6665

67-
def _CreateBenchmarks(
68-
bm_target: str, names: list[str]
69-
) -> dict[str, bm.Benchmark]:
70-
"""Creates benchmark dictionary with the benchmark name as the key."""
71-
benchmarks = {}
72-
for name in names:
73-
benchmark = bm.Benchmark(bm_target, name)
74-
benchmarks[benchmark.Name()] = benchmark
75-
return benchmarks
76-
77-
78-
def _CreateMatchingBenchmarks(
79-
bm_target: str, bm_filter: str, bm_candidates: list[str]
80-
) -> dict[str, bm.Benchmark]:
81-
"""Creates benchmarks that match the given filter."""
82-
matching_bm_names = [name for name in bm_candidates if bm_filter in name]
83-
if not matching_bm_names:
84-
raise ValueError(f"Can't find benchmarks matching {bm_filter}.")
85-
return _CreateBenchmarks(bm_target, matching_bm_names)
86-
87-
88-
def _GetDefaultBenchmarks(
89-
benchmark_target: str, benchmark_filters: list[str]
90-
) -> dict[str, bm.Benchmark]:
91-
"""Get a list of benchmarks from the default target.
92-
93-
Filtering options:
94-
- Empty list: Returns all default benchmarks.
95-
- Keyword list: Returns benchmarks from the default list matching the provided
96-
keyword (e.g., "Cold Hot").
97-
98-
Args:
99-
benchmark_target: Path to the benchmark binary.
100-
benchmark_filters: List of filters to apply to the benchmarks to run.
101-
102-
Returns:
103-
A map of benchmark names to Benchmark objects.
104-
"""
105-
benchmarks = {}
106-
sub_benchmarks = bm.GetSubBenchmarks(benchmark_target)
107-
108-
# Gets default benchmark sets
109-
if not benchmark_filters:
110-
return _CreateBenchmarks(benchmark_target, sub_benchmarks)
111-
112-
# Gets benchmark sets from filters
113-
for bm_filter in benchmark_filters:
114-
benchmarks.update(
115-
_CreateMatchingBenchmarks(benchmark_target, bm_filter, sub_benchmarks)
116-
)
117-
return benchmarks
118-
119-
120-
def _GetWorkloadBenchmarks(
121-
benchmark_target: str, workload_filters: list[str]
122-
) -> dict[str, bm.Benchmark]:
123-
"""Get a list of benchmarks from the given workload that match the filters.
124-
125-
Filtering options:
126-
- Workload name + keyword(s): Returns benchmarks associated with the
127-
specified workload, further filtered by keywords (e.g.,
128-
"libc,Memcpy,Memcmp").
129-
- Workload name + "all": Returns all benchmarks associated with the
130-
specified workload (e.g., "proto,all").
131-
Args:
132-
benchmark_target: Path to the benchmark binary.
133-
workload_filters: List of filters to apply to the benchmarks to run.
134-
135-
Returns:
136-
A map of benchmark names to Benchmark objects.
137-
"""
138-
benchmarks = {}
139-
140-
# Get all unique workloads
141-
workloads = bm.GetWorkloads(benchmark_target)
142-
143-
def _GetWorkloadAndFilter(bm_filter: str) -> tuple[str, list[str]]:
144-
parts = bm_filter.split(",")
145-
if parts[0].upper() not in workloads:
146-
raise ValueError(f"Workload {parts[0]} not supported in Fleetbench.")
147-
return parts[0], parts[1:]
148-
149-
for workload_filter in workload_filters:
150-
workload, bm_filters = _GetWorkloadAndFilter(workload_filter)
151-
workload_bms = bm.GetSubBenchmarks(benchmark_target, workload)
152-
if bm_filters == ["all"]:
153-
benchmarks.update(
154-
_CreateMatchingBenchmarks(
155-
benchmark_target, workload.upper(), workload_bms
156-
)
157-
)
158-
else:
159-
for bm_filter in bm_filters:
160-
benchmarks.update(
161-
_CreateMatchingBenchmarks(benchmark_target, bm_filter, workload_bms)
162-
)
163-
164-
return benchmarks
165-
166-
16766
def _SetExtraBenchmarkFlags(
16867
benchmark_perf_counters: str,
16968
benchmark_repetitions: int,
@@ -254,11 +153,11 @@ def _PreRun(
254153
logging.info("Initializing benchmarks and worker threads...")
255154

256155
if workload_filters:
257-
self.benchmarks = _GetWorkloadBenchmarks(
156+
self.benchmarks = bm.GetWorkloadBenchmarks(
258157
benchmark_target, workload_filters
259158
)
260159
else:
261-
self.benchmarks = _GetDefaultBenchmarks(
160+
self.benchmarks = bm.GetDefaultBenchmarks(
262161
benchmark_target, benchmark_filters
263162
)
264163

@@ -271,8 +170,8 @@ def _PreRun(
271170
for benchmark in self.benchmarks.values():
272171
benchmark.AddCommandFlags(benchmark_flags)
273172

274-
# Initialize the runtimes with a fake wall time of 1. This causes all benchmarks
275-
# to be equally likely at first.
173+
# Initialize the runtimes with a fake wall time of 1. This causes all
174+
# benchmarks to be equally likely at first.
276175
self.runtimes = {
277176
benchmark: [
278177
BenchmarkMetrics(

0 commit comments

Comments
 (0)