Skip to content

Commit 3aaeb88

Browse files
Luodianclaude
andcommitted
refactor: optimize OpenAI compatible endpoint batch processing
- Remove unused decord imports from chat implementation - Add configurable max_concurrent_requests parameter (default: 32) - Optimize file I/O by writing cache once at the end instead of per batch - Improve thread pool sizing based on actual batch size These changes improve performance and reduce unnecessary file I/O operations while maintaining backward compatibility. Reported-by:b8zhong Github-Issue:#835 Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 0bb1001 commit 3aaeb88

2 files changed

Lines changed: 138 additions & 39 deletions

File tree

lmms_eval/models/chat/openai_compatible.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77

88
from lmms_eval.api.registry import register_model
99

10-
try:
11-
from decord import VideoReader, cpu
12-
except ImportError:
13-
pass
10+
# decord imports removed - not used in chat implementation
1411

1512
from dotenv import load_dotenv
1613
from loguru import logger as eval_logger
@@ -32,8 +29,14 @@ def generate_until(self, requests) -> List[str]:
3229
res = []
3330

3431
batch_size = getattr(self, "batch_size_per_gpu", 1)
35-
batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)]
36-
pbar = tqdm(total=len(batched_requests), disable=(self.rank != 0), desc="Model Responding")
32+
batched_requests = [
33+
requests[i : i + batch_size] for i in range(0, len(requests), batch_size)
34+
]
35+
pbar = tqdm(
36+
total=len(batched_requests),
37+
disable=(self.rank != 0),
38+
desc="Model Responding",
39+
)
3740

3841
e2e_latency = 0
3942
total_tokens = 0
@@ -56,7 +59,9 @@ def generate_until(self, requests) -> List[str]:
5659
continue
5760

5861
chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id])
59-
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw})
62+
chat_messages: ChatMessages = ChatMessages(
63+
**{"messages": chat_messages_raw}
64+
)
6065

6166
payload = {"messages": chat_messages.to_openai_messages()}
6267
payload["model"] = self.model_version
@@ -75,7 +80,11 @@ def generate_until(self, requests) -> List[str]:
7580
payload["max_tokens"] = gen_kwargs["max_new_tokens"]
7681
payload["temperature"] = gen_kwargs["temperature"]
7782

78-
if "o1" in self.model_version or "o3" in self.model_version or "o4" in self.model_version:
83+
if (
84+
"o1" in self.model_version
85+
or "o3" in self.model_version
86+
or "o4" in self.model_version
87+
):
7988
del payload["temperature"]
8089
payload.pop("max_tokens")
8190
payload["reasoning_effort"] = "medium"
@@ -108,22 +117,35 @@ def process_single_request(payload, i):
108117

109118
except Exception as e:
110119
error_msg = str(e)
111-
eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}")
120+
eval_logger.info(
121+
f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}"
122+
)
112123

113124
if attempt == self.max_retries - 1:
114-
eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}")
125+
eval_logger.error(
126+
f"All {self.max_retries} attempts failed. Last error: {error_msg}"
127+
)
115128
return "", i, 0, 0
116129
else:
117130
time.sleep(self.timeout)
118131

119132
return "", i, 0, 0
120133

121-
tasks_to_run = [(payload, i) for i, payload in enumerate(batch_payloads) if batch_responses[i] is None]
134+
tasks_to_run = [
135+
(payload, i)
136+
for i, payload in enumerate(batch_payloads)
137+
if batch_responses[i] is None
138+
]
122139

123140
if tasks_to_run:
124-
max_workers = min(len(tasks_to_run), 32)
141+
max_workers = min(
142+
len(tasks_to_run), getattr(self, "max_concurrent_requests", 32)
143+
)
125144
with ThreadPoolExecutor(max_workers=max_workers) as executor:
126-
future_to_index = {executor.submit(process_single_request, payload, i): i for payload, i in tasks_to_run}
145+
future_to_index = {
146+
executor.submit(process_single_request, payload, i): i
147+
for payload, i in tasks_to_run
148+
}
127149

128150
for future in as_completed(future_to_index):
129151
response_text, i, latency, tokens = future.result()
@@ -135,8 +157,6 @@ def process_single_request(payload, i):
135157
for doc_uuid, response_text in zip(batch_doc_uuids, batch_responses):
136158
if response_text is not None:
137159
self.response_cache[doc_uuid] = response_text
138-
with open(self.response_persistent_file, "w") as f:
139-
json.dump(self.response_cache, f)
140160

141161
res.extend([r for r in batch_responses if r is not None])
142162
pbar.update(1)
@@ -152,4 +172,10 @@ def process_single_request(payload, i):
152172
log_metrics(**metric_dict)
153173

154174
pbar.close()
175+
176+
# Write cache once at the end if in continual mode
177+
if self.continual_mode is True:
178+
with open(self.response_persistent_file, "w") as f:
179+
json.dump(self.response_cache, f)
180+
155181
return res

lmms_eval/models/simple/openai_compatible.py

Lines changed: 97 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
max_frames_num: int = 10,
4545
httpx_trust_env: bool = True,
4646
batch_size: int = 64,
47+
max_concurrent_requests: int = 32,
4748
**kwargs,
4849
) -> None:
4950
"""
@@ -57,16 +58,22 @@ def __init__(
5758
self.model_version = model_version
5859
self.timeout = timeout
5960
self.max_retries = max_retries
60-
self.max_size_in_mb = max_size_in_mb # some models have a limit on the size of the image
61+
self.max_size_in_mb = (
62+
max_size_in_mb # some models have a limit on the size of the image
63+
)
6164
self.continual_mode = continual_mode
6265
self.max_frames_num = max_frames_num
6366
if self.continual_mode:
6467
if response_persistent_folder is None:
65-
raise ValueError("Continual mode requires a persistent path for the response. Please provide a valid path.")
68+
raise ValueError(
69+
"Continual mode requires a persistent path for the response. Please provide a valid path."
70+
)
6671

6772
os.makedirs(response_persistent_folder, exist_ok=True)
6873
self.response_persistent_folder = response_persistent_folder
69-
self.response_persistent_file = os.path.join(self.response_persistent_folder, f"{self.model_version}_response.json")
74+
self.response_persistent_file = os.path.join(
75+
self.response_persistent_folder, f"{self.model_version}_response.json"
76+
)
7077

7178
if os.path.exists(self.response_persistent_file):
7279
with open(self.response_persistent_file, "r") as f:
@@ -81,7 +88,11 @@ def __init__(
8188
# settings. openai-python uses a httpx.Client with trust_env set to True. Such a
8289
# httpx.Client uses macOS proxy server settings. Adding httpx_trust_env option
8390
# allows httpx to ignore proxy server settings set by VPN clients.
84-
http_client = DefaultHttpxClient(trust_env=httpx_trust_env) if not httpx_trust_env else None
91+
http_client = (
92+
DefaultHttpxClient(trust_env=httpx_trust_env)
93+
if not httpx_trust_env
94+
else None
95+
)
8596

8697
# Use provided parameters or fall back to environment variables
8798
api_key = api_key or os.getenv("OPENAI_API_KEY")
@@ -98,16 +109,27 @@ def __init__(
98109
self.client = (
99110
OpenAI(api_key=api_key, base_url=base_url, http_client=http_client)
100111
if not azure_openai
101-
else AzureOpenAI(api_key=os.getenv("AZURE_OPENAI_API_KEY"), azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), http_client=http_client)
112+
else AzureOpenAI(
113+
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
114+
azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
115+
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
116+
http_client=http_client,
117+
)
102118
)
103119

104120
accelerator = Accelerator()
105121
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
106122
if accelerator.num_processes > 1:
107-
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
123+
assert accelerator.distributed_type in [
124+
DistributedType.FSDP,
125+
DistributedType.MULTI_GPU,
126+
DistributedType.DEEPSPEED,
127+
], "Unsupported distributed type provided. Only DDP and FSDP are supported."
108128
self.accelerator = accelerator
109129
if self.accelerator.is_local_main_process:
110-
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
130+
eval_logger.info(
131+
f"Using {accelerator.num_processes} devices with data parallelism"
132+
)
111133
self._rank = self.accelerator.local_process_index
112134
self._world_size = self.accelerator.num_processes
113135
else:
@@ -117,6 +139,7 @@ def __init__(
117139

118140
self.device = self.accelerator.device
119141
self.batch_size_per_gpu = int(batch_size)
142+
self.max_concurrent_requests = max_concurrent_requests
120143

121144
@property
122145
def batch_size(self):
@@ -164,11 +187,15 @@ def encode_image(self, image: Union[Image.Image, str]):
164187
def encode_video(self, video_path, for_get_frames_num):
165188
vr = VideoReader(video_path, ctx=cpu(0))
166189
total_frame_num = len(vr)
167-
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
190+
uniform_sampled_frames = np.linspace(
191+
0, total_frame_num - 1, for_get_frames_num, dtype=int
192+
)
168193

169194
# Ensure the last frame is included
170195
if total_frame_num - 1 not in uniform_sampled_frames:
171-
uniform_sampled_frames = np.append(uniform_sampled_frames, total_frame_num - 1)
196+
uniform_sampled_frames = np.append(
197+
uniform_sampled_frames, total_frame_num - 1
198+
)
172199

173200
frame_idx = uniform_sampled_frames.tolist()
174201
frames = vr.get_batch(frame_idx).asnumpy()
@@ -200,9 +227,15 @@ def _collate(x):
200227

201228
from lmms_eval import utils
202229

203-
re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
230+
re_ords = utils.Collator(
231+
[reg.args for reg in requests], _collate, grouping=True
232+
)
204233
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
205-
num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
234+
num_iters = (
235+
len(requests) // self.batch_size
236+
if len(requests) % self.batch_size == 0
237+
else len(requests) // self.batch_size + 1
238+
)
206239
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
207240

208241
for chunk in chunks:
@@ -234,10 +267,24 @@ def _collate(x):
234267
visuals = self.flatten(visuals)
235268
imgs = []
236269
for visual in visuals:
237-
if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual):
270+
if isinstance(visual, str) and (
271+
".mp4" in visual
272+
or ".avi" in visual
273+
or ".mov" in visual
274+
or ".flv" in visual
275+
or ".wmv" in visual
276+
):
238277
frames = self.encode_video(visual, self.max_frames_num)
239278
imgs.extend(frames)
240-
elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual):
279+
elif isinstance(visual, str) and (
280+
".jpg" in visual
281+
or ".jpeg" in visual
282+
or ".png" in visual
283+
or ".gif" in visual
284+
or ".bmp" in visual
285+
or ".tiff" in visual
286+
or ".webp" in visual
287+
):
241288
img = self.encode_image(visual)
242289
imgs.append(img)
243290
elif isinstance(visual, Image.Image):
@@ -248,9 +295,16 @@ def _collate(x):
248295
payload["model"] = self.model_version
249296

250297
payload["messages"].append({"role": "user", "content": []})
251-
payload["messages"][0]["content"].append({"type": "text", "text": context})
298+
payload["messages"][0]["content"].append(
299+
{"type": "text", "text": context}
300+
)
252301
for img in imgs:
253-
payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
302+
payload["messages"][0]["content"].append(
303+
{
304+
"type": "image_url",
305+
"image_url": {"url": f"data:image/png;base64,{img}"},
306+
}
307+
)
254308

255309
if "max_new_tokens" not in gen_kwargs:
256310
gen_kwargs["max_new_tokens"] = 1024
@@ -288,22 +342,33 @@ def process_single_request(payload, i):
288342

289343
except Exception as e:
290344
error_msg = str(e)
291-
eval_logger.info(f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}")
345+
eval_logger.info(
346+
f"Attempt {attempt + 1}/{self.max_retries} failed with error: {error_msg}"
347+
)
292348

293349
if attempt == self.max_retries - 1:
294-
eval_logger.error(f"All {self.max_retries} attempts failed. Last error: {error_msg}")
350+
eval_logger.error(
351+
f"All {self.max_retries} attempts failed. Last error: {error_msg}"
352+
)
295353
return "", i
296354
else:
297355
time.sleep(self.timeout)
298356

299357
return "", i
300358

301-
tasks_to_run = [(payload, i) for i, payload in enumerate(batch_payloads) if batch_responses[i] is None]
359+
tasks_to_run = [
360+
(payload, i)
361+
for i, payload in enumerate(batch_payloads)
362+
if batch_responses[i] is None
363+
]
302364

303365
if tasks_to_run:
304-
max_workers = min(len(tasks_to_run), 32)
366+
max_workers = min(len(tasks_to_run), self.max_concurrent_requests)
305367
with ThreadPoolExecutor(max_workers=max_workers) as executor:
306-
future_to_index = {executor.submit(process_single_request, payload, i): i for payload, i in tasks_to_run}
368+
future_to_index = {
369+
executor.submit(process_single_request, payload, i): i
370+
for payload, i in tasks_to_run
371+
}
307372

308373
for future in as_completed(future_to_index):
309374
response_text, i = future.result()
@@ -313,17 +378,25 @@ def process_single_request(payload, i):
313378
for doc_uuid, response_text in zip(batch_doc_uuids, batch_responses):
314379
if response_text is not None:
315380
self.response_cache[doc_uuid] = response_text
316-
with open(self.response_persistent_file, "w") as f:
317-
json.dump(self.response_cache, f)
318381

319382
res.extend([r for r in batch_responses if r is not None])
320383
pbar.update(1)
321384

322385
pbar.close()
386+
387+
# Write cache once at the end if in continual mode
388+
if self.continual_mode is True:
389+
with open(self.response_persistent_file, "w") as f:
390+
json.dump(self.response_cache, f)
391+
323392
return res
324393

325394
def generate_until_multi_round(self, requests) -> List[str]:
326-
raise NotImplementedError("TODO: Implement multi-round generation for OpenAI compatible models")
395+
raise NotImplementedError(
396+
"TODO: Implement multi-round generation for OpenAI compatible models"
397+
)
327398

328399
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
329-
raise NotImplementedError("TODO: Implement loglikelihood for OpenAI compatible models")
400+
raise NotImplementedError(
401+
"TODO: Implement loglikelihood for OpenAI compatible models"
402+
)

0 commit comments

Comments
 (0)