Skip to content

Commit 7e53385

Browse files
liyuying0000 authored and copybara-github committed
Add support for --workload_filter in parallel benchmark.
PiperOrigin-RevId: 663807393 Change-Id: I3d83d9584cadaf32f890e9e7b70256b9bc2e5890
1 parent 3f82b13 commit 7e53385

File tree

5 files changed

+277
-42
lines changed

5 files changed

+277
-42
lines changed

fleetbench/parallel/benchmark.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Represent a Fleetbench benchmark."""
1616

1717
import os
18+
import re
1819
import subprocess
1920

2021
from absl import flags
@@ -44,12 +45,36 @@ def _FindBenchmarkPath(benchmark: str) -> str:
4445
raise FileNotFoundError(f"Benchmark not found: {benchmark}")
4546

4647

47-
def GetSubBenchmarks(benchmark_path: str, workload: str = "") -> list[str]:
  """Lists the sub-benchmarks exposed by a benchmark executable.

  Args:
    benchmark_path: Path to the benchmark binary to query.
    workload: Optional workload selector. When "all", every benchmark is
      listed; when any other non-empty value, only benchmarks matching the
      regex "BM_<WORKLOAD>" (upper-cased) are listed; when empty, the
      binary's default benchmark set is listed.

  Returns:
    The benchmark names printed by the binary, one per output line.
  """
  command = [benchmark_path, "--benchmark_list_tests"]
  if workload:
    # "all" is google/benchmark's special filter value that matches every
    # benchmark; otherwise narrow to the requested workload's name pattern.
    filter_value = "all" if workload == "all" else f"BM_{workload.upper()}"
    command.append(f"--benchmark_filter={filter_value}")
  completed = subprocess.run(
      command, capture_output=True, text=True, check=True
  )
  # The listing ends with a trailing newline, so drop the final empty field.
  return completed.stdout.split("\n")[:-1]
5164

5265

66+
def GetWorkloads(benchmark_path: str) -> list[str]:
  """Retrieves a list of unique workloads from a benchmark executable.

  A workload is the first underscore-delimited token after "BM_" in a
  benchmark name (e.g. "PROTO" for "BM_PROTO_Arena").

  Args:
    benchmark_path: Path to the benchmark binary to query.

  Returns:
    The unique workload names in first-seen order. Benchmark names with no
    "BM_" token are ignored.
  """
  benchmarks = GetSubBenchmarks(benchmark_path, "all")
  # Compile once instead of re-parsing the pattern for every benchmark name.
  workload_pattern = re.compile(r"BM_(?P<workload>[^_]+)")

  # Dict keys preserve insertion order, so the result is deterministic;
  # the previous list(set(...)) ordering varied between interpreter runs.
  workloads: dict[str, None] = {}
  for name in benchmarks:
    match = workload_pattern.search(name)
    if match:
      workloads[match.group("workload")] = None
  return list(workloads)
76+
77+
5378
class Benchmark:
5479
"""Represents a benchmark binary and filter."""
5580

fleetbench/parallel/benchmark_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,62 @@ def testCommandLine(self):
9090
],
9191
)
9292

93+
@mock.patch.object(subprocess, "run", autospec=True)
94+
def testGetWorkloads(self, mock_run):
95+
mock_run.return_value = subprocess.CompletedProcess(
96+
args=[],
97+
returncode=0,
98+
stdout=(
99+
"BM_LIBC_Test1\nBM_LIBC_Test2\nBM_PROTO_Test1\nBM_TCMALLOC_Test1\n"
100+
),
101+
stderr="",
102+
)
103+
self.assertCountEqual(
104+
benchmark.GetWorkloads("/path/to/benchmark"),
105+
["LIBC", "PROTO", "TCMALLOC"],
106+
)
107+
108+
@mock.patch.object(subprocess, "run", autospec=True)
109+
def testGetSubBenchmarksWorkload(self, mock_run):
110+
mock_run.return_value = subprocess.CompletedProcess(
111+
args=[],
112+
returncode=0,
113+
stdout="BM_PROTO_Arena\nBM_PROTO_NoArena\n",
114+
stderr="",
115+
)
116+
self.assertEqual(
117+
benchmark.GetSubBenchmarks("/path/to/benchmark", "proto"),
118+
["BM_PROTO_Arena", "BM_PROTO_NoArena"],
119+
)
120+
121+
@mock.patch.object(subprocess, "run", autospec=True)
122+
def testGetSubBenchmarksWorkloadWithUnmatchedBM(self, mock_run):
123+
# Simulate the full list of benchmarks
124+
full_benchmarks = ["BM_PROTO_Arena", "BM_PROTO_NoArena", "BM_CORD_Fleet"]
125+
126+
# Simulate the subprocess output
127+
mock_run.return_value = subprocess.CompletedProcess(
128+
args=[],
129+
returncode=0,
130+
stdout="\n".join(full_benchmarks),
131+
stderr="",
132+
)
133+
134+
sub_benchmarks = benchmark.GetSubBenchmarks("/path/to/benchmark", "proto")
135+
136+
# Assert the expected behavior
137+
self.assertEqual(sub_benchmarks, ["BM_PROTO_Arena", "BM_PROTO_NoArena"])
138+
mock_run.assert_called_once_with(
139+
[
140+
"/path/to/benchmark",
141+
"--benchmark_list_tests",
142+
"--benchmark_filter=BM_PROTO",
143+
],
144+
capture_output=True,
145+
text=True,
146+
check=True,
147+
)
148+
93149

94150
if __name__ == "__main__":
95151
absltest.main()

fleetbench/parallel/parallel_bench.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,32 @@
4949
)
5050

5151
# Selects a subset of the default benchmark list by keyword.
_BENCHMARK_FILTER = flags.DEFINE_multi_string(
    "benchmark_filter",
    [],
    """Specifies subset of benchmarks to run.

    Filtering options:
    - Empty list: Selects all default benchmarks.
    - Keyword list: Selects benchmarks from the default list matching any
      provided keyword, one keyword per filter
      (e.g., "--benchmark_filter=Cold --benchmark_filter=Hot").""",
)


# Selects benchmarks by workload; takes precedence over --benchmark_filter.
_WORKLOAD_FILTER = flags.DEFINE_multi_string(
    "workload_filter",
    [],
    """Selects benchmarks associated with specified workloads. This will
    overwrite the `--benchmark_filter` flag.

    Filtering options:
    - Workload name + keyword(s): Selects benchmarks associated with the
      specified workload, further filtered by keywords
      (e.g., "--workload_filter=libc,Memcpy,Memcmp").
    - Workload name + "all": Selects all benchmarks associated with the
      specified workload
      (e.g., "--workload_filter=proto,all").
    """,
)
5479

5580
_BENCHMARK_PERF_COUNTERS = flags.DEFINE_string(
@@ -112,6 +137,7 @@ def main(argv: Sequence[str]) -> None:
112137
results = bench.Run(
113138
benchmark_target=_BENCHMARK_TARGET.value,
114139
benchmark_filter=_BENCHMARK_FILTER.value,
140+
workload_filter=_WORKLOAD_FILTER.value,
115141
benchmark_perf_counters=_BENCHMARK_PERF_COUNTERS.value,
116142
benchmark_repetitions=_BENCHMARK_REPETITIONS.value,
117143
benchmark_min_time=_BENCHMARK_MIN_TIME.value,

fleetbench/parallel/parallel_bench_lib.py

Lines changed: 96 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,88 @@
3232
from fleetbench.parallel import worker
3333

3434

35-
def _CreateBenchmarks(
    bm_target: str, names: list[str]
) -> dict[str, bm.Benchmark]:
  """Builds a mapping from benchmark name to Benchmark for the given names."""
  result = {}
  for sub_benchmark in names:
    entry = bm.Benchmark(bm_target, sub_benchmark)
    # Key by the Benchmark's own Name() rather than the raw input string.
    result[entry.Name()] = entry
  return result
44+
45+
46+
def _CreateMatchingBenchmarks(
    bm_target: str, bm_filter: str, bm_candidates: list[str]
) -> dict[str, bm.Benchmark]:
  """Creates benchmarks whose names contain the filter as a substring.

  Raises:
    ValueError: If no candidate name contains `bm_filter`.
  """
  selected = [
      candidate for candidate in bm_candidates if bm_filter in candidate
  ]
  if not selected:
    raise ValueError(f"Can't find benchmarks matching {bm_filter}.")
  return _CreateBenchmarks(bm_target, selected)
54+
55+
56+
def _GetDefaultBenchmarks(
    benchmark_target: str, benchmark_filters: list[str]
) -> dict[str, bm.Benchmark]:
  """Get a list of benchmarks from the default target.

  Filtering options:
  - Empty list: Returns all default benchmarks.
  - Keyword list: Returns benchmarks from the default list matching the
    provided keyword (e.g., "Cold Hot").
  """
  sub_benchmarks = bm.GetSubBenchmarks(benchmark_target)

  # No filters: take the target's default benchmark set wholesale.
  if not benchmark_filters:
    return _CreateBenchmarks(benchmark_target, sub_benchmarks)

  # Otherwise, union the matches for every filter keyword; a keyword with
  # no matches raises inside _CreateMatchingBenchmarks.
  selected: dict[str, bm.Benchmark] = {}
  for keyword in benchmark_filters:
    selected.update(
        _CreateMatchingBenchmarks(benchmark_target, keyword, sub_benchmarks)
    )
  return selected
79+
80+
81+
def _GetWorkloadBenchmarks(
    benchmark_target: str, workload_filters: list[str]
) -> dict[str, bm.Benchmark]:
  """Get a list of benchmarks from the given workload that match the filters.

  Filtering options:
  - Workload name + keyword(s): Returns benchmarks associated with the
    specified workload, further filtered by keywords (e.g.,
    "libc,Memcpy,Memcmp").
  - Workload name + "all": Returns all benchmarks associated with the
    specified workload (e.g., "proto,all").

  Raises:
    ValueError: If a filter names a workload the target does not provide.
  """
  # All unique workloads the target supports, used to validate filters.
  supported_workloads = bm.GetWorkloads(benchmark_target)

  def _split_filter(spec: str) -> tuple[str, list[str]]:
    # "<workload>,<kw1>,<kw2>,..." -> workload name plus keyword list.
    workload_name, *keywords = spec.split(",")
    if workload_name.upper() not in supported_workloads:
      raise ValueError(f"Workload {workload_name} not supported in Fleetbench.")
    return workload_name, keywords

  benchmarks: dict[str, bm.Benchmark] = {}
  for spec in workload_filters:
    workload, keywords = _split_filter(spec)
    workload_bms = bm.GetSubBenchmarks(benchmark_target, workload)
    if keywords == ["all"]:
      # "all" keeps the workload's whole benchmark list, matched via the
      # upper-cased workload name used in BM_<WORKLOAD> benchmark names.
      benchmarks.update(
          _CreateMatchingBenchmarks(
              benchmark_target, workload.upper(), workload_bms
          )
      )
    else:
      # NOTE(review): a spec with no keywords (e.g. just "proto") selects
      # nothing silently here — confirm whether that should raise instead.
      for keyword in keywords:
        benchmarks.update(
            _CreateMatchingBenchmarks(benchmark_target, keyword, workload_bms)
        )

  return benchmarks
@@ -131,6 +192,7 @@ def _PreRun(
131192
self,
132193
benchmark_target: str,
133194
benchmark_filters: list[str],
195+
workload_filters: list[str],
134196
benchmark_perf_counters: str,
135197
benchmark_repetitions: int,
136198
benchmark_min_time: str,
@@ -139,7 +201,14 @@ def _PreRun(
139201

140202
logging.info("Initializing benchmarks and worker threads...")
141203

142-
self.benchmarks = _GetBenchmarks(benchmark_target, benchmark_filters)
204+
if workload_filters:
205+
self.benchmarks = _GetWorkloadBenchmarks(
206+
benchmark_target, workload_filters
207+
)
208+
else:
209+
self.benchmarks = _GetDefaultBenchmarks(
210+
benchmark_target, benchmark_filters
211+
)
143212

144213
benchmark_flags = _SetExtraBenchmarkFlags(
145214
benchmark_perf_counters, benchmark_repetitions, benchmark_min_time
@@ -331,14 +400,25 @@ def Run(
331400
self,
332401
benchmark_target: str,
333402
benchmark_filter: list[str] = [],
403+
workload_filter: list[str] = [],
334404
benchmark_perf_counters: str = "",
335405
benchmark_repetitions: int = 0,
336406
benchmark_min_time: str = "",
337407
) -> list[result.Result]:
338408
"""Run benchmarks in parallel."""
409+
logging.info("Running with benchmark_filter: %s", benchmark_filter)
410+
logging.info("Running with workload_filter: %s", workload_filter)
411+
412+
if benchmark_filter and workload_filter:
413+
logging.warning(
414+
"Both benchmark_filter and workload_filter specified. "
415+
"benchmark_filter will be ignored."
416+
)
417+
339418
self._PreRun(
340419
benchmark_target,
341420
benchmark_filter,
421+
workload_filter,
342422
benchmark_perf_counters,
343423
benchmark_repetitions,
344424
benchmark_min_time,

0 commit comments

Comments
 (0)