Skip to content

Commit 9dee01d

Browse files
liyuying0000copybara-github
authored andcommitted
Implement custom weights support for other scheduling strategies.
PiperOrigin-RevId: 741572656 Change-Id: Ic0a9dd124157991427be94b3dddb204262d8e097
1 parent 8cd2a72 commit 9dee01d

File tree

2 files changed

+100
-30
lines changed

2 files changed

+100
-30
lines changed

fleetbench/parallel/weights.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,35 @@ def GetDCTaxWeights() -> dict[str, float]:
8080
}
8181

8282

83+
def _UpdateWorkloadWeights(
84+
benchmark_weights: dict[str, float],
85+
workload_benchmark: dict[str, list[bm.Benchmark]],
86+
custom_weights: dict[str, float],
87+
) -> dict[str, float]:
88+
"""Updates benchmark weights based on workload and custom weights."""
89+
for benchmarks in workload_benchmark.values():
90+
# If there is only one benchmark in the workload, there is no need to
91+
# apply the custom weights.
92+
if len(benchmarks) == 1:
93+
continue
94+
95+
for benchmark in benchmarks:
96+
for keyword, weight in custom_weights.items():
97+
if keyword in benchmark.BenchmarkName().upper():
98+
# Update the weight for the benchmark in the custom weights list.
99+
benchmark_weights[benchmark.BenchmarkName()] *= weight
100+
101+
# Normalize the weights again for each workload.
102+
# We need first sum all weights for current workload, then divide
103+
# individual weights by sum.
104+
sum_weights = sum(
105+
benchmark_weights[benchmark.BenchmarkName()] for benchmark in benchmarks
106+
)
107+
for benchmark in benchmarks:
108+
benchmark_weights[benchmark.BenchmarkName()] /= sum_weights
109+
return benchmark_weights
110+
111+
83112
def GetBenchmarkWeights(
84113
benchmarks: dict[str, bm.Benchmark],
85114
scheduling_strategy: SchedulingStrategy,
@@ -111,8 +140,6 @@ def GetBenchmarkWeights(
111140
benchmark_weights = {
112141
benchmark.BenchmarkName(): 1.0 for benchmark in benchmarks.values()
113142
}
114-
# TODO: add custom weights support for the other scheduling
115-
# strategies.
116143
# Update the benchmark weights with the custom weights.
117144
if custom_weights:
118145
benchmark_weights.update(custom_weights)
@@ -136,10 +163,22 @@ def GetBenchmarkWeights(
136163
for benchmark in benchmarks:
137164
benchmark_weights[benchmark.BenchmarkName()] = 1 / len(benchmarks)
138165

166+
# Update the benchmark weights with the custom weights if provided.
167+
if custom_weights:
168+
benchmark_weights = _UpdateWorkloadWeights(
169+
benchmark_weights, workload_benchmark, custom_weights
170+
)
171+
139172
elif scheduling_strategy == SchedulingStrategy.DCTAX_WEIGHTED:
140173
# Get the DCTax based weights for each benchmark
141174
# For OSS, feel free to adjust the weights.csv file.
142175
benchmark_weights = GetDCTaxWeights()
176+
177+
if custom_weights:
178+
logging.warning(
179+
"Custom weights are not supported for DCTAX_WEIGHTED scheduling"
180+
" strategy. We will ignore the custom weights."
181+
)
143182
else:
144183
raise ValueError(
145184
"Unsupported scheduling strategy: %s" % scheduling_strategy.name

fleetbench/parallel/weights_test.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -69,33 +69,6 @@ def test_bm_weighted_no_custom_weights(self):
6969
expected_weights,
7070
)
7171

72-
def test_dctax_weighted(self):
73-
benchmark_map = {
74-
"benchmark1": bm.Benchmark("fake_target", "BM_TCMalloc_Test1"),
75-
"benchmark2": bm.Benchmark("fake_target", "BM_Cord_Test2"),
76-
"benchmark3": bm.Benchmark("fake_target", "BM_Protobuf_Test1"),
77-
}
78-
expected_weights = {
79-
"BM_TCMalloc_Test1": 0.4,
80-
"BM_Cord_Test2": 0.1,
81-
"BM_Protobuf_Test1": 0.2,
82-
}
83-
with absltest.mock.patch.object(
84-
weights,
85-
"GetDCTaxWeights",
86-
return_value={
87-
"BM_TCMalloc_Test1": 0.4,
88-
"BM_Cord_Test2": 0.1,
89-
"BM_Protobuf_Test1": 0.2,
90-
},
91-
):
92-
self.assertEqual(
93-
weights.GetBenchmarkWeights(
94-
benchmark_map, weights.SchedulingStrategy.DCTAX_WEIGHTED
95-
),
96-
expected_weights,
97-
)
98-
9972
def test_bm_weighted_custom_weights(self):
10073
benchmark_map = {
10174
"benchmark1": bm.Benchmark("fake_target", "benchmark1"),
@@ -115,7 +88,7 @@ def test_bm_weighted_custom_weights(self):
11588
expected_weights,
11689
)
11790

118-
def test_workload_weighted(self):
91+
def test_workload_weighted_no_custom_weights(self):
11992
benchmark_map = {
12093
"benchmark1": bm.Benchmark("fake_target", "BM_Workload1_Test1"),
12194
"benchmark2": bm.Benchmark("fake_target", "BM_Workload1_Test2"),
@@ -132,3 +105,61 @@ def test_workload_weighted(self):
132105
),
133106
expected_weights,
134107
)
108+
109+
def test_workload_weighted_custom_weights(self):
110+
benchmark_map = {
111+
"benchmark1": bm.Benchmark("fake_target", "BM_Workload1_Test1"),
112+
"benchmark2": bm.Benchmark("fake_target", "BM_Workload1_Test2"),
113+
"benchmark3": bm.Benchmark("fake_target", "BM_Workload2_Test1"),
114+
}
115+
custom_weights = ["Test1:3"]
116+
expected_weights = {
117+
"benchmark1": 0.75,
118+
"benchmark2": 0.25,
119+
"benchmark3": 1.0,
120+
}
121+
self.assertEqual(
122+
weights.GetBenchmarkWeights(
123+
benchmark_map,
124+
weights.SchedulingStrategy.WORKLOAD_WEIGHTED,
125+
custom_weights,
126+
),
127+
expected_weights,
128+
)
129+
130+
def test_dctax_weighted(self):
131+
benchmark_map = {
132+
"benchmark1": bm.Benchmark("fake_target", "BM_TCMalloc_Test1"),
133+
"benchmark2": bm.Benchmark("fake_target", "BM_Cord_Test2"),
134+
"benchmark3": bm.Benchmark("fake_target", "BM_Protobuf_Test1"),
135+
}
136+
expected_weights = {
137+
"BM_TCMalloc_Test1": 0.4,
138+
"BM_Cord_Test2": 0.1,
139+
"BM_Protobuf_Test1": 0.2,
140+
}
141+
custom_weights = ["Test1:3"]
142+
with absltest.mock.patch.object(
143+
weights,
144+
"GetDCTaxWeights",
145+
return_value={
146+
"BM_TCMalloc_Test1": 0.4,
147+
"BM_Cord_Test2": 0.1,
148+
"BM_Protobuf_Test1": 0.2,
149+
},
150+
):
151+
self.assertEqual(
152+
weights.GetBenchmarkWeights(
153+
benchmark_map, weights.SchedulingStrategy.DCTAX_WEIGHTED
154+
),
155+
expected_weights,
156+
)
157+
# Custom weights are ignored for DCTAX_WEIGHTED scheduling strategy.
158+
self.assertEqual(
159+
weights.GetBenchmarkWeights(
160+
benchmark_map,
161+
weights.SchedulingStrategy.DCTAX_WEIGHTED,
162+
custom_weights,
163+
),
164+
expected_weights,
165+
)

0 commit comments

Comments
 (0)