Skip to content

Commit 0168d3b

Browse files
liyuying0000 and copybara-github
authored and committed
Create a weights.py file for all functionalities to manipulate benchmark weights.
PiperOrigin-RevId: 736648945 Change-Id: If970a454623bbe9e47e59db8f6c780a2c075294d
1 parent 8bfe762 commit 0168d3b

File tree

7 files changed

+234
-113
lines changed

7 files changed

+234
-113
lines changed

fleetbench/parallel/BUILD

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ py_library(
5858
":cpu",
5959
":reporter",
6060
":run",
61+
":weights",
6162
":worker",
6263
"@com_google_absl_py//absl/logging",
6364
requirement("numpy"),
@@ -74,6 +75,25 @@ py_library(
7475
],
7576
)
7677

78+
py_library(
79+
name = "weights",
80+
srcs = ["weights.py"],
81+
deps = [
82+
":benchmark",
83+
"@com_google_absl_py//absl/logging",
84+
],
85+
)
86+
87+
py_test(
88+
name = "weights_test",
89+
srcs = ["weights_test.py"],
90+
deps = [
91+
":benchmark",
92+
":weights",
93+
"@com_google_absl_py//absl/testing:absltest",
94+
],
95+
)
96+
7797
py_binary(
7898
name = "parallel_bench",
7999
testonly = True,

fleetbench/parallel/benchmark.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def Name(self):
122122
def Path(self):
123123
return f"{self._path}"
124124

125+
def BenchmarkName(self):
126+
return self._benchmark_filter
127+
125128
def __str__(self):
126129
return self.Name()
127130

fleetbench/parallel/parallel_bench.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
)
106106

107107

108-
_BENCHMARK_WEIGHTS = flags.DEFINE_multi_string(
108+
_CUSTOM_BENCHMARK_WEIGHTS = flags.DEFINE_multi_string(
109109
"benchmark_weights",
110110
[],
111111
"Weights for selected benchmarks. Default weight of 1.0 is used if not"
@@ -154,25 +154,22 @@ def main(argv: Sequence[str]) -> None:
154154
)
155155
cpus, target_utilization = scheduling_mode.SelectCPURangeAndSetUtilization()
156156

157-
# Parse benchmark weights.
158-
benchmark_weights = parallel_bench_lib.ParseBenchmarkWeights(
159-
_BENCHMARK_WEIGHTS.value
160-
)
161-
logging.info("Running with selected benchmark weights: %s", benchmark_weights)
162-
163157
bench = parallel_bench_lib.ParallelBench(
164158
cpus=cpus,
165159
cpu_affinity=_CPU_AFFINITY.value,
166-
benchmark_weights=benchmark_weights,
167160
utilization=target_utilization,
168161
duration=_DURATION.value,
169162
temp_root=_TEMP_ROOT.value,
170163
)
171164

165+
bench.SetWeights(
166+
_BENCHMARK_TARGET.value,
167+
_BENCHMARK_FILTER.value,
168+
_WORKLOAD_FILTER.value,
169+
_CUSTOM_BENCHMARK_WEIGHTS.value,
170+
)
171+
172172
bench.Run(
173-
benchmark_target=_BENCHMARK_TARGET.value,
174-
benchmark_filter=_BENCHMARK_FILTER.value,
175-
workload_filter=_WORKLOAD_FILTER.value,
176173
benchmark_perf_counters=_BENCHMARK_PERF_COUNTERS.value,
177174
benchmark_repetitions=_BENCHMARK_REPETITIONS.value,
178175
benchmark_min_time=_BENCHMARK_MIN_TIME.value,

fleetbench/parallel/parallel_bench_lib.py

Lines changed: 31 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -28,42 +28,10 @@
2828
from fleetbench.parallel import cpu
2929
from fleetbench.parallel import reporter
3030
from fleetbench.parallel import run
31+
from fleetbench.parallel import weights
3132
from fleetbench.parallel import worker
3233

3334

34-
def ParseBenchmarkWeights(
35-
benchmark_weights_list: list[str],
36-
) -> dict[str, float] | None:
37-
"""Parses a list of benchmark weights into a dictionary.
38-
39-
The string element in the list should be in the format:
40-
<benchmark_name|benchmark_filter>:<weight>. Note that the benchmark name or
41-
filter should be in ALL CAPS to ensure case-insensitive matching.
42-
43-
Args:
44-
benchmark_weights_list: A list of strings to parse.
45-
46-
Returns:
47-
A dictionary of {capitalized string: float} or None if the list is empty.
48-
"""
49-
if not benchmark_weights_list:
50-
return None
51-
52-
benchmark_weights = {}
53-
for weights in benchmark_weights_list:
54-
try:
55-
key, value_str = weights.split(":")
56-
benchmark_weights[key.upper()] = float(value_str)
57-
except ValueError:
58-
logging.warning(
59-
"Invalid benchmark string: %s. The format should be"
60-
" <benchmark_name|benchmark_filter>:<weight>. Skipping...",
61-
weights,
62-
)
63-
64-
return benchmark_weights
65-
66-
6735
def _SetExtraBenchmarkFlags(
6836
benchmark_perf_counters: str,
6937
benchmark_repetitions: int,
@@ -119,7 +87,6 @@ def __init__(
11987
self,
12088
cpus: list[int],
12189
cpu_affinity: bool,
122-
benchmark_weights: dict[str, float] | None,
12390
utilization: float,
12491
duration: float,
12592
temp_root: str,
@@ -133,7 +100,7 @@ def __init__(
133100
self.controller_cpu = cpus[0]
134101
self.cpus = cpus[1:]
135102
self.cpu_affinity = cpu_affinity
136-
self.benchmark_weights = benchmark_weights
103+
self.benchmark_weights: dict[str, float] = {}
137104
self.benchmarks: dict[str, bm.Benchmark] = {}
138105
self.target_utilization = utilization * 100
139106
self.duration = duration
@@ -142,27 +109,44 @@ def __init__(
142109
self.workers: dict[int, worker.Worker] = {}
143110
self.utilization_samples: list[tuple[pd.Timestamp, float]] = []
144111

145-
def _PreRun(
112+
def SetWeights(
146113
self,
147114
benchmark_target: str,
148-
benchmark_filters: list[str] | None,
149-
workload_filters: list[str] | None,
150-
benchmark_perf_counters: str,
151-
benchmark_repetitions: int,
152-
benchmark_min_time: str,
115+
benchmark_filter: list[str] | None,
116+
workload_filter: list[str] | None,
117+
custom_benchmark_weights: list[str] | None,
153118
) -> None:
154-
"""Initial configuration steps."""
155-
156-
logging.info("Initializing benchmarks and worker threads...")
119+
"""Sets the benchmark weights."""
120+
logging.info("Running with benchmark_filter: %s", benchmark_filter)
121+
logging.info("Running with workload_filter: %s", workload_filter)
157122

158-
if workload_filters:
123+
if benchmark_filter and workload_filter:
124+
logging.warning(
125+
"Both benchmark_filter and workload_filter specified. "
126+
"benchmark_filter will be ignored."
127+
)
128+
if workload_filter:
159129
self.benchmarks = bm.GetWorkloadBenchmarks(
160-
benchmark_target, workload_filters
130+
benchmark_target, workload_filter
161131
)
162132
else:
163133
self.benchmarks = bm.GetDefaultBenchmarks(
164-
benchmark_target, benchmark_filters
134+
benchmark_target, benchmark_filter
165135
)
136+
# Gets the number of workloads and num of benchmark for each workload
137+
self.benchmark_weights = weights.GetBenchmarkWeights(
138+
self.benchmarks, custom_benchmark_weights
139+
)
140+
141+
def _PreRun(
142+
self,
143+
benchmark_perf_counters: str,
144+
benchmark_repetitions: int,
145+
benchmark_min_time: str,
146+
) -> None:
147+
"""Initial configuration steps."""
148+
149+
logging.info("Initializing benchmarks and worker threads...")
166150

167151
benchmark_flags = _SetExtraBenchmarkFlags(
168152
benchmark_perf_counters, benchmark_repetitions, benchmark_min_time
@@ -402,27 +386,13 @@ def PostProcessBenchmarkResults(self, benchmark_perf_counters: str) -> None:
402386

403387
def Run(
404388
self,
405-
benchmark_target: str,
406-
benchmark_filter: list[str] | None = None,
407-
workload_filter: list[str] | None = None,
408389
benchmark_perf_counters: str = "",
409390
benchmark_repetitions: int = 0,
410391
benchmark_min_time: str = "",
411392
):
412393
"""Run benchmarks in parallel."""
413-
logging.info("Running with benchmark_filter: %s", benchmark_filter)
414-
logging.info("Running with workload_filter: %s", workload_filter)
415-
416-
if benchmark_filter and workload_filter:
417-
logging.warning(
418-
"Both benchmark_filter and workload_filter specified. "
419-
"benchmark_filter will be ignored."
420-
)
421394

422395
self._PreRun(
423-
benchmark_target,
424-
benchmark_filter,
425-
workload_filter,
426396
benchmark_perf_counters,
427397
benchmark_repetitions,
428398
benchmark_min_time,

fleetbench/parallel/parallel_bench_lib_test.py

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def setUp(self):
3535
self.pb = parallel_bench_lib.ParallelBench(
3636
cpus=[0, 1],
3737
cpu_affinity=False,
38-
benchmark_weights=None,
3938
utilization=0.5,
4039
duration=0.1,
4140
temp_root=absltest.get_default_test_tmpdir(),
@@ -89,12 +88,17 @@ def fake_utilization(unused_cpus):
8988
self.pb = parallel_bench_lib.ParallelBench(
9089
cpus=[0, 1],
9190
cpu_affinity=False,
92-
benchmark_weights=None,
9391
utilization=0.5,
9492
duration=0.1,
9593
temp_root=absltest.get_default_test_tmpdir(),
9694
)
97-
self.pb.Run("fake_bench", [])
95+
self.pb.SetWeights(
96+
benchmark_target="fake_bench",
97+
benchmark_filter=None,
98+
workload_filter=None,
99+
custom_benchmark_weights=None,
100+
)
101+
self.pb.Run()
98102
mock_execute.assert_called_once()
99103

100104
def test_set_extra_benchmark_flags(self):
@@ -241,43 +245,5 @@ def test_post_processing_benchmark_results(
241245
mock_save_benchmark_results.assert_called_once()
242246

243247

244-
class ParseBenchmarkWeightsTest(absltest.TestCase):
245-
246-
def test_empty_list(self):
247-
self.assertIsNone(parallel_bench_lib.ParseBenchmarkWeights([]))
248-
249-
def test_valid_list(self):
250-
benchmark_list = [
251-
"cold:0.5",
252-
"PROTO:0.3",
253-
"Cord:0.2",
254-
]
255-
expected_output = {
256-
"COLD": 0.5,
257-
"PROTO": 0.3,
258-
"CORD": 0.2,
259-
}
260-
self.assertEqual(
261-
parallel_bench_lib.ParseBenchmarkWeights(benchmark_list),
262-
expected_output,
263-
"Should return correct dictionary for valid list",
264-
)
265-
266-
def test_invalid_string(self):
267-
benchmark_list = [
268-
"cold:0.5",
269-
"PROTO:invalid",
270-
"TCMALLOC",
271-
"Cord:0.2",
272-
]
273-
# Even with an invalid string, the function should still process the valid
274-
# ones.
275-
expected_output = {"COLD": 0.5, "CORD": 0.2}
276-
self.assertEqual(
277-
parallel_bench_lib.ParseBenchmarkWeights(benchmark_list),
278-
expected_output,
279-
)
280-
281-
282248
if __name__ == "__main__":
283249
absltest.main()

fleetbench/parallel/weights.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 The Fleetbench Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""This file contains functions for handling benchmark weights."""
17+
18+
from absl import logging
19+
from fleetbench.parallel import benchmark as bm
20+
21+
22+
def _ParseBenchmarkWeights(
23+
benchmark_weights_list: list[str],
24+
) -> dict[str, float] | None:
25+
"""Parses a list of benchmark weights into a dictionary.
26+
27+
The string element in the list should be in the format:
28+
<benchmark_name|benchmark_filter>:<weight>. Note that the benchmark name or
29+
filter should be in ALL CAPS to ensure case-insensitive matching.
30+
31+
Args:
32+
benchmark_weights_list: A list of strings to parse.
33+
34+
Returns:
35+
A dictionary of {capitalized string: float} or None if the list is empty.
36+
"""
37+
if not benchmark_weights_list:
38+
return None
39+
40+
benchmark_weights = {}
41+
for weights in benchmark_weights_list:
42+
try:
43+
key, value_str = weights.split(":")
44+
benchmark_weights[key.upper()] = float(value_str)
45+
except ValueError:
46+
logging.warning(
47+
"Invalid benchmark string: %s. The format should be"
48+
" <benchmark_name|benchmark_filter>:<weight>. Skipping...",
49+
weights,
50+
)
51+
52+
return benchmark_weights
53+
54+
55+
def GetBenchmarkWeights(
    benchmarks: dict[str, bm.Benchmark],
    custom_weight_strings: list[str] | None = None,
) -> dict[str, float]:
  """Computes the weight assigned to each benchmark.

  Every benchmark starts at the default weight of 1.0. When
  custom_weight_strings is provided, any weights parsed from it override
  the defaults for the matching benchmark names.

  Args:
    benchmarks: A dictionary of {benchmark name: Benchmark object}.
    custom_weight_strings: Optional "name:weight" strings to parse for
      custom weights.

  Returns:
    A dictionary of {benchmark name: weight}.
  """
  result = {b.BenchmarkName(): 1.0 for b in benchmarks.values()}

  if not custom_weight_strings:
    return result

  logging.info(
      "Running with selected benchmark weights: %s", custom_weight_strings
  )
  overrides = _ParseBenchmarkWeights(custom_weight_strings)
  if overrides:
    result.update(overrides)
  return result

0 commit comments

Comments
 (0)