Skip to content

Commit 9c7a55a

Browse files
authored
fix: preserve per-request vllm sampling params (#1326)
1 parent bd71e82 commit 9c7a55a

3 files changed

Lines changed: 184 additions & 6 deletions

File tree

lmms_eval/models/chat/vllm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,24 +93,27 @@ def generate_until(self, requests) -> List[GenerationResult]:
9393
sample_token_counts: Optional[TokenCounts] = None
9494
for batch_requests in batched_requests:
9595
batched_messages = []
96+
batched_sampling_params = []
9697
with ThreadPoolExecutor(max_workers=WORKERS) as executor:
9798
futures = [executor.submit(self.make_one_request, request) for request in batch_requests]
9899
for future in futures:
99100
messages, sampling_params = future.result()
100101
batched_messages.append(messages)
102+
batched_sampling_params.append(sampling_params)
101103

102-
sampling_params = SamplingParams(**sampling_params)
103104
start_time = time.time()
104105

105-
def _run_chat(inputs: list[dict]) -> list[str]:
106+
def _run_chat(request_items: list[tuple[list[dict], dict]]) -> list[str]:
107+
inputs = [messages for messages, _ in request_items]
108+
sampling_params = [SamplingParams(**params) for _, params in request_items]
106109
response = self.client.chat(
107110
sampling_params=sampling_params,
108111
messages=inputs,
109112
chat_template=self.chat_template,
110113
)
111114
return [o.outputs[0].text for o in response]
112115

113-
response_text = self._run_tp_synced(batched_messages, _run_chat)
116+
response_text = self._run_tp_synced(list(zip(batched_messages, batched_sampling_params)), _run_chat)
114117
end_time = time.time()
115118

116119
# Calculate timing metrics for batch

lmms_eval/models/chat/vllm_generate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,23 @@ def generate_until(self, requests) -> List[GenerationResult]:
159159
sample_token_counts: Optional[TokenCounts] = None
160160
for batch_requests in batched_requests:
161161
batched_vllm_inputs = []
162+
batched_sampling_params = []
162163
with ThreadPoolExecutor(max_workers=WORKERS) as executor:
163164
futures = [executor.submit(self.make_one_request, request) for request in batch_requests]
164165
for future in futures:
165166
vllm_inputs, sampling_params = future.result()
166167
batched_vllm_inputs.append(vllm_inputs)
168+
batched_sampling_params.append(sampling_params)
167169

168-
sampling_params = SamplingParams(**sampling_params)
169170
start_time = time.time()
170171

171-
def _run_generate(inputs: list[dict]) -> list[str]:
172+
def _run_generate(request_items: list[tuple[dict, dict]]) -> list[str]:
173+
inputs = [vllm_inputs for vllm_inputs, _ in request_items]
174+
sampling_params = [SamplingParams(**params) for _, params in request_items]
172175
response = self.client.generate(inputs, sampling_params)
173176
return [o.outputs[0].text for o in response]
174177

175-
response_text = self._run_tp_synced(batched_vllm_inputs, _run_generate)
178+
response_text = self._run_tp_synced(list(zip(batched_vllm_inputs, batched_sampling_params)), _run_generate)
176179
end_time = time.time()
177180

178181
# Calculate timing metrics for batch
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from __future__ import annotations
2+
3+
import importlib.util
4+
import pathlib
5+
import sys
6+
import types
7+
import unittest
8+
from unittest.mock import patch
9+
10+
11+
def _install_vllm_stubs() -> None:
    """Publish lightweight stand-ins for several lmms_eval modules in ``sys.modules``.

    Every stand-in is described by a ``types.SimpleNamespace`` and is turned
    into a real module object before being stored under its dotted name.
    """

    class _FakeVLLMSimple:
        """Stub exposed as ``VLLM`` by the fake ``lmms_eval.models.simple.vllm``."""

        @property
        def rank(self):
            return self._rank

        def _run_tp_synced(self, local_inputs, run_fn):
            # The stub just applies ``run_fn`` to the local inputs.
            return run_fn(local_inputs)

    def _generation_result(text, token_counts=None):
        return types.SimpleNamespace(text=text, token_counts=token_counts)

    def _register_model(_name):
        # Decorator factory whose decorator hands the class back untouched.
        return lambda cls: cls

    def _optional_import(*_args):
        return None, False

    def _log_metrics(**_kwargs):
        return None

    stubs = {
        "lmms_eval.api.instance": types.SimpleNamespace(
            GenerationResult=_generation_result,
            Instance=object,
            TokenCounts=object,
        ),
        "lmms_eval.api.registry": types.SimpleNamespace(register_model=_register_model),
        "lmms_eval.imports": types.SimpleNamespace(optional_import=_optional_import),
        "lmms_eval.models.model_utils.gen_metrics": types.SimpleNamespace(log_metrics=_log_metrics),
        "lmms_eval.models.simple.vllm": types.SimpleNamespace(VLLM=_FakeVLLMSimple),
        "lmms_eval.protocol": types.SimpleNamespace(ChatMessages=object),
    }
    for dotted_name, stub in stubs.items():
        if isinstance(stub, types.ModuleType):
            sys.modules[dotted_name] = stub
        else:
            sys.modules[dotted_name] = _namespace_module(dotted_name, stub)
34+
35+
36+
def _namespace_module(name: str, namespace) -> types.ModuleType:
    """Build a module called ``name`` carrying every attribute of ``namespace``."""
    stub_module = types.ModuleType(name)
    for attr_name, attr_value in vars(namespace).items():
        stub_module.__dict__[attr_name] = attr_value
    return stub_module
40+
41+
42+
def _load_module(module_name: str, relative_path: str):
    """Import the file at ``relative_path`` and register it as ``module_name``.

    ``relative_path`` is joined onto the directory two levels above the
    directory that contains this file. The new module is placed in
    ``sys.modules`` before its code runs; if running it fails, the previous
    ``sys.modules`` state for ``module_name`` is put back and the error is
    re-raised.

    Args:
        module_name: Dotted name under which the module is registered.
        relative_path: Path of the source file, relative to the repo root.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the file.
    """
    repo_root = pathlib.Path(__file__).resolve().parents[2]
    module_path = repo_root / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Raise explicitly rather than assert: asserts are stripped under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {module_name!r} from {module_path}")
    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered under this name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
52+
53+
54+
# Dotted names that are stubbed or loaded while this file is imported; their
# ``sys.modules`` entries are put back to their prior state afterwards.
_STUBBED_MODULES = (
    "lmms_eval.api.instance",
    "lmms_eval.api.registry",
    "lmms_eval.imports",
    "lmms_eval.models.model_utils.gen_metrics",
    "lmms_eval.models.simple.vllm",
    "lmms_eval.protocol",
    "lmms_eval.models.chat.vllm",
    "lmms_eval.models.chat.vllm_generate",
)
# Snapshot of the current entries; ``None`` stands for "not present".
_original_modules = dict(zip(_STUBBED_MODULES, map(sys.modules.get, _STUBBED_MODULES)))
try:
    # The two backends are loaded from source while the stubs are installed.
    _install_vllm_stubs()
    _vllm_chat = _load_module("lmms_eval.models.chat.vllm", "lmms_eval/models/chat/vllm.py")
    _vllm_generate = _load_module("lmms_eval.models.chat.vllm_generate", "lmms_eval/models/chat/vllm_generate.py")
finally:
    # Restore the snapshot whether or not loading succeeded.
    for name, module in _original_modules.items():
        if module is not None:
            sys.modules[name] = module
            continue
        sys.modules.pop(name, None)

VLLMChat = _vllm_chat.VLLM
VLLMGenerate = _vllm_generate.VLLMGenerate
78+
79+
80+
class _FakeSamplingParams:
    """Stand-in for ``SamplingParams`` that only remembers its keyword arguments."""

    def __init__(self, **params):
        # Stored as ``kwargs`` so tests can compare exactly what was passed in.
        self.kwargs = params
83+
84+
85+
class _CaptureClient:
    """Fake client that records every ``chat``/``generate`` call and returns canned outputs."""

    def __init__(self):
        self.calls = []

    @staticmethod
    def _canned_outputs(prefix, items):
        # One response per item; its first output's text is "<prefix>-<index>".
        canned = []
        for position, _ in enumerate(items):
            first_output = types.SimpleNamespace(text=f"{prefix}-{position}")
            canned.append(types.SimpleNamespace(outputs=[first_output]))
        return canned

    def chat(self, *, messages, sampling_params, chat_template):
        """Record the keyword arguments and return one ``chat-<i>`` response per message."""
        record = {
            "messages": messages,
            "sampling_params": sampling_params,
            "chat_template": chat_template,
        }
        self.calls.append(record)
        return self._canned_outputs("chat", messages)

    def generate(self, inputs, sampling_params):
        """Record the arguments and return one ``generate-<i>`` response per input."""
        record = {
            "inputs": inputs,
            "sampling_params": sampling_params,
        }
        self.calls.append(record)
        return self._canned_outputs("generate", inputs)
107+
108+
109+
def _request(name: str):
    """Return a minimal request object exposing only a ``name`` attribute."""
    fake_request = types.SimpleNamespace()
    fake_request.name = name
    return fake_request
111+
112+
113+
def _configure_model(model, client: _CaptureClient) -> None:
    """Attach ``client`` and the fixed attribute values used by these tests to ``model``."""
    settings = {
        "client": client,
        "batch_size_per_gpu": 2,
        "_rank": 0,
        "_tp_world_size": 1,
        "_tp_group_cpu": None,
        "disable_log_stats": True,
        "chat_template": None,
    }
    for attribute, value in settings.items():
        setattr(model, attribute, value)
121+
122+
123+
class TestVLLMSamplingParams(unittest.TestCase):
    """Check that each request's own sampling params reach the client, in order."""

    def _generate_with_fake_requests(self, model_cls, backend_module, build_inputs, params_by_request):
        """Run ``generate_until`` on a bare ``model_cls`` instance with a capture client.

        One request is created per key of ``params_by_request`` (in insertion
        order); ``make_one_request`` is replaced so it returns
        ``(build_inputs(name), params_by_request[name])``. Returns the capture
        client together with the results of ``generate_until``.
        """
        capture = _CaptureClient()
        # ``__new__`` creates the instance without running ``__init__``.
        backend = model_cls.__new__(model_cls)
        _configure_model(backend, capture)
        backend.make_one_request = lambda request: (build_inputs(request.name), params_by_request[request.name])

        requests = [_request(request_name) for request_name in params_by_request]
        with patch.object(backend_module, "SamplingParams", _FakeSamplingParams):
            results = backend.generate_until(requests)
        return capture, results

    def test_chat_backend_keeps_per_request_sampling_params(self):
        expected = {
            "short": {"max_tokens": 16, "temperature": 0, "top_p": 1.0},
            "long": {"max_tokens": 128, "temperature": 0.7, "top_p": 0.8},
        }

        capture, results = self._generate_with_fake_requests(
            VLLMChat,
            _vllm_chat,
            lambda request_name: [{"role": "user", "content": request_name}],
            expected,
        )

        self.assertEqual([result.text for result in results], ["chat-0", "chat-1"])
        self.assertEqual(len(capture.calls), 1)
        forwarded = capture.calls[0]["sampling_params"]
        self.assertEqual([params.kwargs for params in forwarded], [expected["short"], expected["long"]])

    def test_generate_backend_keeps_per_request_sampling_params(self):
        expected = {
            "ocr": {"max_tokens": 128, "temperature": 0, "top_p": 1.0},
            "vqa": {"max_tokens": 32, "temperature": 0.2, "top_p": 0.9},
        }

        capture, results = self._generate_with_fake_requests(
            VLLMGenerate,
            _vllm_generate,
            lambda request_name: {"prompt": request_name, "multi_modal_data": {}},
            expected,
        )

        self.assertEqual([result.text for result in results], ["generate-0", "generate-1"])
        self.assertEqual(len(capture.calls), 1)
        forwarded = capture.calls[0]["sampling_params"]
        self.assertEqual([params.kwargs for params in forwarded], [expected["ocr"], expected["vqa"]])
169+
170+
171+
if __name__ == "__main__":
172+
unittest.main()

0 commit comments

Comments
 (0)