Commit b35b057

Enable exception logging in _query_endpoint_async
1 parent 7e0c444 commit b35b057

File tree

  • multimodal/vl2l/src/mlperf_inference_multimodal_vl2l

1 file changed: +173 -112 lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/task.py

Lines changed: 173 additions & 112 deletions
@@ -56,8 +56,11 @@ def __init__(
             token=dataset_cli.token,
             split="+".join(dataset_cli.split),
         )
-        logger.debug("Loaded {} samples from the dataset splits {}.", len(self.dataset),
-                     dataset_cli.split)
+        logger.debug(
+            "Loaded {} samples from the dataset splits {}.",
+            len(self.dataset),
+            dataset_cli.split,
+        )
         self.model_cli = model_cli
         self.openai_api_client = AsyncOpenAI(
             base_url=endpoint_cli.url,
@@ -106,8 +109,7 @@ def _run_event_loop_forever() -> None:

     @staticmethod
     @abstractmethod
-    def formulate_messages(
-            sample: dict[str, Any]) -> list[ChatCompletionMessageParam]:
+    def formulate_messages(sample: dict[str, Any]) -> list[ChatCompletionMessageParam]:
         """Formulate the messages for chat completion.

         Args:
@@ -208,51 +210,71 @@ def _issue_queries(query_samples: list[lg.QuerySample]) -> None:
                 `lg.QuerySampleIndex` (i.e., the sample index into the dataset).
             """

-            async def _query_endpoint_async(
-                    query_sample: lg.QuerySample) -> None:
+            async def _query_endpoint_async(query_sample: lg.QuerySample) -> None:
                 """Query the endpoint through the async OpenAI API client."""
-                messages = self.loaded_messages[query_sample.index]
-                logger.trace(
-                    "Issuing query sample index: {} with response ID: {}",
-                    query_sample.index,
-                    query_sample.id,
-                )
-                tic = time.perf_counter()
-                response = await self.openai_api_client.chat.completions.create(
-                    model=self.model_cli.repo_id,
-                    messages=messages,
-                )
-                logger.trace(
-                    "Received response (ID: {}) from endpoint after {} seconds: {}",
-                    query_sample.id,
-                    time.perf_counter() - tic,
-                    response,
-                )
-                content = response.choices[0].message.content
-                if content is None:
-                    content = ""
-                bytes_array = array.array("B", content.encode("utf-8"))
-                address, length = bytes_array.buffer_info()
-                size_in_bytes = length * bytes_array.itemsize
-                lg.QuerySamplesComplete(
-                    [
-                        lg.QuerySampleResponse(
-                            query_sample.id,
-                            address,
-                            size_in_bytes,
-                            int(response.usage.completion_tokens),
-                        ),
-                    ],
-                )
+                try:
+                    messages = self.loaded_messages[query_sample.index]
+                    logger.trace(
+                        "Issuing query sample index: {} with response ID: {}",
+                        query_sample.index,
+                        query_sample.id,
+                    )
+                    tic = time.perf_counter()
+                    response = await self.openai_api_client.chat.completions.create(
+                        model=self.model_cli.repo_id,
+                        messages=messages,
+                    )
+                    logger.trace(
+                        "Received response (ID: {}) from endpoint after {} seconds: {}",
+                        query_sample.id,
+                        time.perf_counter() - tic,
+                        response,
+                    )
+                    content = response.choices[0].message.content
+                    if content is None:
+                        content = ""
+                    bytes_array = array.array("B", content.encode("utf-8"))
+                    address, length = bytes_array.buffer_info()
+                    size_in_bytes = length * bytes_array.itemsize
+                    lg.QuerySamplesComplete(
+                        [
+                            lg.QuerySampleResponse(
+                                query_sample.id,
+                                address,
+                                size_in_bytes,
+                                int(response.usage.completion_tokens),
+                            ),
+                        ],
+                    )
+                except Exception:  # noqa: BLE001
+                    logger.exception(
+                        "Error processing query sample index {} with response ID {}.",
+                        query_sample.index,
+                        query_sample.id,
+                    )
+                    # Send empty response to LoadGen to avoid hanging.
+                    empty_content = ""
+                    bytes_array = array.array("B", empty_content.encode("utf-8"))
+                    address, length = bytes_array.buffer_info()
+                    size_in_bytes = length * bytes_array.itemsize
+                    lg.QuerySamplesComplete(
+                        [
+                            lg.QuerySampleResponse(
+                                query_sample.id,
+                                address,
+                                size_in_bytes,
+                                0,
+                            ),
+                        ],
+                    )

             for query_sample in query_samples:
                 asyncio.run_coroutine_threadsafe(
                     _query_endpoint_async(query_sample),
                     self.event_loop,
                 )

-        def _issue_streaming_queries(
-                query_samples: list[lg.QuerySample]) -> None:
+        def _issue_streaming_queries(query_samples: list[lg.QuerySample]) -> None:
             """Called by the LoadGen to issue queries to the inference endpoint.

             Args:
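For context, the new `except` branch pairs exception logging with an empty completion so LoadGen is never left waiting on a failed sample. A minimal standalone sketch of that fallback pattern, assuming loguru (which matches the brace-style `logger.trace` calls above) and eliding the final `lg.QuerySamplesComplete` call, which needs a live LoadGen test:

import array

from loguru import logger  # assumption: the brace-style log calls above are loguru's

try:
    raise RuntimeError("endpoint unreachable")  # hypothetical stand-in failure
except Exception:  # noqa: BLE001
    # logger.exception logs at ERROR level and appends the full traceback.
    logger.exception("Error processing query sample index {} with response ID {}.", 0, 1)
    # Empty payload -> the (address, size) pair lg.QuerySampleResponse expects.
    # The array must stay alive until LoadGen has copied the buffer.
    bytes_array = array.array("B", b"")
    address, length = bytes_array.buffer_info()
    size_in_bytes = length * bytes_array.itemsize  # itemsize == 1 for "B"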
@@ -262,70 +284,106 @@ def _issue_streaming_queries(
                 `lg.QuerySampleIndex` (i.e., the sample index into the dataset).
             """

-            async def _query_endpoint_async(
-                    query_sample: lg.QuerySample) -> None:
+            async def _query_endpoint_async(query_sample: lg.QuerySample) -> None:
                 """Query the endpoint through the async OpenAI API client."""
-                messages = self.loaded_messages[query_sample.index]
-                logger.trace(
-                    "Issuing query sample index: {} with response ID: {}",
-                    query_sample.index,
-                    query_sample.id,
-                )
                 ttft_set = False
-                word_array = []
-                stream = await self.openai_api_client.chat.completions.create(
-                    stream=True,
-                    model=self.model_cli.repo_id,
-                    messages=messages,
-                    stream_options={"include_usage": True},
-                )
-                # iterate asynchronously
-                total_tokens = 0
-                async for chunk in stream:
-
-                    # This is the final chunk and will not have 'choices'
-                    if chunk.usage is not None:
-                        total_tokens = int(chunk.usage.completion_tokens)
-
-                    # If it's not the usage chunk, process it as a content
-                    # chunk
-                    choices = getattr(chunk, "choices", None)
-                    if not choices:
-                        continue
-                    # first non-empty token -> TTFT
-                    delta = choices[0].delta
-                    text = getattr(delta, "content", None)
-                    if not text:
-                        continue
-                    if ttft_set is False:
-                        bytes_array = array.array(
-                            "B", text.encode("utf-8"))
-                        address, length = bytes_array.buffer_info()
-                        size_in_bytes = length * bytes_array.itemsize
-                        lg.FirstTokenComplete([
-                            lg.QuerySampleResponse(query_sample.id,
-                                                   address,
-                                                   size_in_bytes,
-                                                   1),
-                        ])
-                        ttft_set = True
-                    word_array.append(text)
-
-                # when the stream ends, total latency
-                content = "".join(word_array)
-                bytes_array = array.array("B", content.encode("utf-8"))
-                address, length = bytes_array.buffer_info()
-                size_in_bytes = length * bytes_array.itemsize
-                lg.QuerySamplesComplete(
-                    [
-                        lg.QuerySampleResponse(
-                            query_sample.id,
-                            address,
-                            size_in_bytes,
-                            total_tokens,
-                        ),
-                    ],
-                )
+                try:
+                    messages = self.loaded_messages[query_sample.index]
+                    logger.trace(
+                        "Issuing query sample index: {} with response ID: {}",
+                        query_sample.index,
+                        query_sample.id,
+                    )
+                    word_array = []
+                    stream = await self.openai_api_client.chat.completions.create(
+                        stream=True,
+                        model=self.model_cli.repo_id,
+                        messages=messages,
+                        stream_options={"include_usage": True},
+                    )
+                    # iterate asynchronously
+                    total_tokens = 0
+                    async for chunk in stream:

+                        # This is the final chunk and will not have 'choices'
+                        if chunk.usage is not None:
+                            total_tokens = int(chunk.usage.completion_tokens)

+                        # If it's not the usage chunk, process it as a content
+                        # chunk
+                        choices = getattr(chunk, "choices", None)
+                        if not choices:
+                            continue
+                        # first non-empty token -> TTFT
+                        delta = choices[0].delta
+                        text = getattr(delta, "content", None)
+                        if not text:
+                            continue
+                        if ttft_set is False:
+                            bytes_array = array.array("B", text.encode("utf-8"))
+                            address, length = bytes_array.buffer_info()
+                            size_in_bytes = length * bytes_array.itemsize
+                            lg.FirstTokenComplete(
+                                [
+                                    lg.QuerySampleResponse(
+                                        query_sample.id,
+                                        address,
+                                        size_in_bytes,
+                                        1,
+                                    ),
+                                ],
+                            )
+                            ttft_set = True
+                        word_array.append(text)

+                    # when the stream ends, total latency
+                    content = "".join(word_array)
+                    bytes_array = array.array("B", content.encode("utf-8"))
+                    address, length = bytes_array.buffer_info()
+                    size_in_bytes = length * bytes_array.itemsize
+                    lg.QuerySamplesComplete(
+                        [
+                            lg.QuerySampleResponse(
+                                query_sample.id,
+                                address,
+                                size_in_bytes,
+                                total_tokens,
+                            ),
+                        ],
+                    )
+                except Exception:  # noqa: BLE001
+                    logger.exception(
+                        "Error processing query sample index {} with response ID {}.",
+                        query_sample.index,
+                        query_sample.id,
+                    )
+                    # Send empty response to LoadGen to avoid hanging.
+                    empty_content = ""
+                    bytes_array = array.array("B", empty_content.encode("utf-8"))
+                    address, length = bytes_array.buffer_info()
+                    size_in_bytes = length * bytes_array.itemsize
+                    # If TTFT was not set, we still need to complete that.
+                    if not ttft_set:
+                        lg.FirstTokenComplete(
+                            [
+                                lg.QuerySampleResponse(
+                                    query_sample.id,
+                                    address,
+                                    size_in_bytes,
+                                    0,
+                                ),
+                            ],
+                        )
+                    lg.QuerySamplesComplete(
+                        [
+                            lg.QuerySampleResponse(
+                                query_sample.id,
+                                address,
+                                size_in_bytes,
+                                0,
+                            ),
+                        ],
+                    )

             for query_sample in query_samples:
                 asyncio.run_coroutine_threadsafe(
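The streaming loop above distinguishes the final usage-only chunk from content chunks and reports TTFT exactly once, on the first non-empty token. A self-contained sketch of that consumption logic, using hypothetical dataclass stand-ins for the chunk objects the OpenAI client yields when `stream_options={"include_usage": True}` is set:

from __future__ import annotations

import asyncio
from dataclasses import dataclass

# Hypothetical stand-ins for OpenAI streaming chunks: content chunks carry
# choices[0].delta.content; the final usage-only chunk has empty choices.
@dataclass
class Delta:
    content: str | None

@dataclass
class Choice:
    delta: Delta

@dataclass
class Usage:
    completion_tokens: int

@dataclass
class Chunk:
    choices: list[Choice]
    usage: Usage | None

async def consume(stream) -> tuple[str, int]:
    ttft_set = False
    total_tokens = 0
    words: list[str] = []
    async for chunk in stream:
        if chunk.usage is not None:  # final chunk: token accounting only
            total_tokens = int(chunk.usage.completion_tokens)
        if not chunk.choices:
            continue
        text = chunk.choices[0].delta.content
        if not text:
            continue
        if not ttft_set:  # first non-empty token -> TTFT
            ttft_set = True  # (the SUT calls lg.FirstTokenComplete here)
        words.append(text)
    return "".join(words), total_tokens  # then lg.QuerySamplesComplete

async def fake_stream():
    yield Chunk([Choice(Delta("Hello"))], None)
    yield Chunk([Choice(Delta(" world"))], None)
    yield Chunk([], Usage(completion_tokens=2))

print(asyncio.run(consume(fake_stream())))  # -> ('Hello world', 2)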
@@ -358,9 +416,14 @@ async def _wait_for_pending_queries_async() -> None:
             )
             future.result()

-        return lg.ConstructSUT(_issue_streaming_queries
-                               if self.scenario is TestScenario.SERVER
-                               else _issue_queries, _flush_queries)
+        return lg.ConstructSUT(
+            (
+                _issue_streaming_queries
+                if self.scenario is TestScenario.SERVER
+                else _issue_queries
+            ),
+            _flush_queries,
+        )


 class ShopifyGlobalCatalogue(Task):
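Both issue-queries callbacks hand each coroutine to a long-lived background loop via `asyncio.run_coroutine_threadsafe` (cf. `_run_event_loop_forever` in the earlier hunk). A minimal sketch of that pattern with only the standard library; the names here are illustrative, not the module's own:

import asyncio
import threading

# The loop lives on a dedicated daemon thread, while the caller
# (LoadGen, in the file above) runs on a different thread.
event_loop = asyncio.new_event_loop()
threading.Thread(target=event_loop.run_forever, daemon=True).start()

async def query(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for the endpoint round trip
    return i * 2

# run_coroutine_threadsafe may be called from any thread and returns a
# concurrent.futures.Future tied to the coroutine's completion.
futures = [asyncio.run_coroutine_threadsafe(query(i), event_loop) for i in range(3)]
print([f.result() for f in futures])  # [0, 2, 4]
event_loop.call_soon_threadsafe(event_loop.stop)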
@@ -392,8 +455,7 @@ def __init__(
         )

     @staticmethod
-    def formulate_messages(
-            sample: dict[str, Any]) -> list[ChatCompletionMessageParam]:
+    def formulate_messages(sample: dict[str, Any]) -> list[ChatCompletionMessageParam]:
         """Formulate the messages for chat completion.

         Args:
@@ -403,8 +465,8 @@ def formulate_messages(
             The messages for chat completion.
         """
         image_file = BytesIO()
-        sample["product_image"].save(
-            image_file, format=sample["product_image"].format)
+        image_format = sample["product_image"].format
+        sample["product_image"].save(image_file, format=image_format)
         image_bytes = image_file.getvalue()
         image_base64 = base64.b64encode(image_bytes)
         image_base64_string = image_base64.decode("utf-8")
@@ -458,8 +520,7 @@ def formulate_messages(
                 {
                     "type": "image_url",
                     "image_url": {
-                        "url":
-                        f"data:image/{sample['product_image'].format};base64,"
+                        "url": f"data:image/{image_format};base64,"
                         f"{image_base64_string}",
                     },
                 },
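The last two hunks hoist `sample["product_image"].format` into `image_format` so the same value drives both the re-encode and the data URL. A small sketch of that base64 data-URL construction with Pillow, using a synthetic in-memory image; the "PNG" fallback is an assumption of this sketch (images created with `Image.new()` have `.format` of None, whereas images loaded from a dataset carry a real format):

import base64
from io import BytesIO

from PIL import Image  # Pillow

# Synthetic stand-in for sample["product_image"].
product_image = Image.new("RGB", (8, 8), color="red")
image_format = product_image.format or "PNG"  # assumed fallback for in-memory images

image_file = BytesIO()
product_image.save(image_file, format=image_format)
image_base64_string = base64.b64encode(image_file.getvalue()).decode("utf-8")

# The value placed in the message's "image_url" field.
url = f"data:image/{image_format};base64,{image_base64_string}"
print(url[:40] + "...")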
