Skip to content

Commit 226bc60

Browse files
authored
other(tests): tests for bucketing manager (#521)
1 parent 9b64414 commit 226bc60

File tree

1 file changed

+207
-0
lines changed

1 file changed

+207
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright 2025 Rebellions Inc. All rights reserved.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from vllm_rbln.v1.worker.bucketing import (
18+
ExponentialBucketingManager,
19+
LinearBucketingManager,
20+
ManualBucketingManager,
21+
RBLNBucketingManager,
22+
get_bucketing_manager,
23+
)
24+
25+
26+
class DummyManager(RBLNBucketingManager):
27+
def _build_decode_buckets(self):
28+
self.decode_batch_buckets = [2, 4, 8]
29+
30+
31+
def test_base_manager_properties_and_find():
32+
manager = DummyManager(max_batch_size=8)
33+
assert manager.decode_batch_buckets == [2, 4, 8]
34+
assert manager.batch_buckets == [1, 2, 4, 8]
35+
assert manager.decode_batch_buckets_count == 3
36+
assert manager.batch_buckets_count == 4
37+
assert manager.find_decode_batch_bucket(1) == 2
38+
assert manager.find_decode_batch_bucket(4) == 4
39+
assert manager.find_decode_batch_bucket(7) == 8
40+
41+
42+
def test_base_manager_find_decode_bucket_not_found():
43+
manager = DummyManager(max_batch_size=8)
44+
with pytest.raises(ValueError, match="No batch bucket found"):
45+
manager.find_decode_batch_bucket(9)
46+
47+
48+
def test_base_manager_abstract_method_not_defined():
49+
class Foo(RBLNBucketingManager):
50+
pass
51+
52+
with pytest.raises(TypeError):
53+
Foo(1)
54+
55+
class Bar(RBLNBucketingManager):
56+
def _build_decode_buckets(self):
57+
super()._build_decode_buckets()
58+
59+
with pytest.raises(
60+
NotImplementedError, match="Subclasses must implement this method"
61+
):
62+
Bar(1)
63+
64+
65+
@pytest.mark.parametrize(
66+
"kwargs, expected_error",
67+
[
68+
pytest.param(
69+
{"max_batch_size": 1, "min_batch_size": 2, "limit": 1, "step": 1},
70+
"max_batch_size must be >= min_batch_size",
71+
id="max_lt_min",
72+
),
73+
pytest.param(
74+
{"max_batch_size": 2, "min_batch_size": 1, "limit": 0, "step": 1},
75+
"limit must be greater than 0",
76+
id="non_positive_limit",
77+
),
78+
pytest.param(
79+
{"max_batch_size": 2, "min_batch_size": 1, "limit": 1, "step": 0},
80+
"step must be greater than 0",
81+
id="non_positive_step",
82+
),
83+
pytest.param(
84+
{"max_batch_size": 2, "min_batch_size": 0, "limit": 1, "step": 1},
85+
"min_batch_size must be greater than 0",
86+
id="non_positive_min",
87+
),
88+
],
89+
)
90+
def test_check_config_raises_for_invalid_config(kwargs, expected_error):
91+
with pytest.raises(ValueError, match=expected_error):
92+
RBLNBucketingManager.check_config(**kwargs)
93+
94+
95+
def test_check_config_allows_valid_config():
96+
RBLNBucketingManager.check_config(
97+
max_batch_size=8,
98+
min_batch_size=1,
99+
limit=4,
100+
step=2,
101+
)
102+
103+
104+
def test_exponential_bucketing_manager_builds_and_stops_at_limit():
105+
manager = ExponentialBucketingManager(
106+
max_batch_size=64,
107+
min_batch_size=4,
108+
limit=4,
109+
step=2,
110+
)
111+
assert manager.decode_batch_buckets == [8, 16, 32, 64]
112+
assert manager.batch_buckets == [1, 8, 16, 32, 64]
113+
114+
115+
def test_exponential_bucketing_manager_breaks_when_under_minimum():
116+
manager = ExponentialBucketingManager(
117+
max_batch_size=10,
118+
min_batch_size=6,
119+
limit=5,
120+
step=2,
121+
)
122+
assert manager.decode_batch_buckets == [10]
123+
124+
125+
def test_exponential_bucketing_manager_requires_step_over_one():
126+
with pytest.raises(ValueError, match="step must be greater than 1"):
127+
ExponentialBucketingManager(
128+
max_batch_size=8,
129+
min_batch_size=1,
130+
limit=2,
131+
step=1,
132+
)
133+
134+
135+
def test_linear_bucketing_manager_builds_and_stops_at_limit():
136+
manager = LinearBucketingManager(
137+
max_batch_size=10,
138+
min_batch_size=1,
139+
limit=4,
140+
step=3,
141+
)
142+
assert manager.decode_batch_buckets == [1, 4, 7, 10]
143+
assert manager.batch_buckets == [1, 4, 7, 10]
144+
145+
146+
def test_linear_bucketing_manager_breaks_when_under_minimum():
147+
manager = LinearBucketingManager(
148+
max_batch_size=10,
149+
min_batch_size=8,
150+
limit=5,
151+
step=3,
152+
)
153+
assert manager.decode_batch_buckets == [10]
154+
155+
156+
def test_manual_bucketing_manager_builds_sorted_unique_buckets():
157+
manager = ManualBucketingManager(
158+
max_batch_size=8,
159+
manual_buckets=[8, 2, 4, 8],
160+
)
161+
assert manager.decode_batch_buckets == [2, 4, 8]
162+
assert manager.batch_buckets == [1, 2, 4, 8]
163+
164+
165+
def test_manual_bucketing_manager_requires_non_empty_buckets():
166+
with pytest.raises(AssertionError, match="manual_buckets must be non-empty"):
167+
ManualBucketingManager(max_batch_size=8, manual_buckets=[])
168+
with pytest.raises(AssertionError, match="manual_buckets must be non-empty"):
169+
get_bucketing_manager("manual", max_batch_size=8)
170+
171+
172+
def test_manual_bucketing_manager_requires_last_bucket_to_match_max():
173+
with pytest.raises(ValueError, match="last manual bucket"):
174+
ManualBucketingManager(
175+
max_batch_size=8,
176+
manual_buckets=[2, 4, 7],
177+
)
178+
179+
180+
def test_get_bucketing_manager_for_all_strategies():
181+
exp_manager = get_bucketing_manager(
182+
"exponential",
183+
max_batch_size=8,
184+
min_batch_size=1,
185+
limit=3,
186+
step=2,
187+
)
188+
linear_manager = get_bucketing_manager(
189+
"linear",
190+
max_batch_size=8,
191+
min_batch_size=1,
192+
limit=3,
193+
step=2,
194+
)
195+
manual_manager = get_bucketing_manager(
196+
"manual",
197+
max_batch_size=8,
198+
manual_buckets=[2, 8],
199+
)
200+
assert isinstance(exp_manager, ExponentialBucketingManager)
201+
assert isinstance(linear_manager, LinearBucketingManager)
202+
assert isinstance(manual_manager, ManualBucketingManager)
203+
204+
205+
def test_get_bucketing_manager_rejects_invalid_strategy():
206+
with pytest.raises(ValueError, match="Invalid bucketing strategy"):
207+
get_bucketing_manager("unknown", max_batch_size=8)

0 commit comments

Comments
 (0)