@@ -56,8 +56,11 @@ def __init__(
5656 token = dataset_cli .token ,
5757 split = "+" .join (dataset_cli .split ),
5858 )
59- logger .debug ("Loaded {} samples from the dataset splits {}." , len (self .dataset ),
60- dataset_cli .split )
59+ logger .debug (
60+ "Loaded {} samples from the dataset splits {}." ,
61+ len (self .dataset ),
62+ dataset_cli .split ,
63+ )
6164 self .model_cli = model_cli
6265 self .openai_api_client = AsyncOpenAI (
6366 base_url = endpoint_cli .url ,
@@ -106,8 +109,7 @@ def _run_event_loop_forever() -> None:
106109
107110 @staticmethod
108111 @abstractmethod
109- def formulate_messages (
110- sample : dict [str , Any ]) -> list [ChatCompletionMessageParam ]:
112+ def formulate_messages (sample : dict [str , Any ]) -> list [ChatCompletionMessageParam ]:
111113 """Formulate the messages for chat completion.
112114
113115 Args:
@@ -208,51 +210,71 @@ def _issue_queries(query_samples: list[lg.QuerySample]) -> None:
208210 `lg.QuerySampleIndex` (i.e., the sample index into the dataset).
209211 """
210212
async def _query_endpoint_async(query_sample: lg.QuerySample) -> None:
    """Run one non-streaming chat completion and report it to the LoadGen.

    The pre-formulated messages for `query_sample.index` are sent through the
    async OpenAI API client. The response text (or an empty string when the
    endpoint returns no content) is handed to `lg.QuerySamplesComplete`
    together with the completion token count taken from `response.usage`.

    Any exception raised along the way is logged and answered with an empty,
    zero-token response, so that the LoadGen is never left waiting for this
    sample.

    Args:
        query_sample: The LoadGen query sample to serve; `index` selects the
            messages and `id` identifies the response.
    """

    def _report_completion(text: str, n_tokens: int) -> None:
        """Hand `text` (UTF-8 encoded) and `n_tokens` to the LoadGen."""
        # `buffer_info()` exposes a raw address that is only valid while the
        # array exists, so keep the array in a local until the call returns.
        payload = array.array("B", text.encode("utf-8"))
        address, length = payload.buffer_info()
        lg.QuerySamplesComplete(
            [
                lg.QuerySampleResponse(
                    query_sample.id,
                    address,
                    length * payload.itemsize,
                    n_tokens,
                ),
            ],
        )

    try:
        messages = self.loaded_messages[query_sample.index]
        logger.trace(
            "Issuing query sample index: {} with response ID: {}",
            query_sample.index,
            query_sample.id,
        )
        started_at = time.perf_counter()
        response = await self.openai_api_client.chat.completions.create(
            model=self.model_cli.repo_id,
            messages=messages,
        )
        logger.trace(
            "Received response (ID: {}) from endpoint after {} seconds: {}",
            query_sample.id,
            time.perf_counter() - started_at,
            response,
        )
        content = response.choices[0].message.content
        _report_completion(
            "" if content is None else content,
            int(response.usage.completion_tokens),
        )
    except Exception:  # noqa: BLE001
        logger.exception(
            "Error processing query sample index {} with response ID {}.",
            query_sample.index,
            query_sample.id,
        )
        # Send empty response to LoadGen to avoid hanging.
        _report_completion("", 0)
247270
248271 for query_sample in query_samples :
249272 asyncio .run_coroutine_threadsafe (
250273 _query_endpoint_async (query_sample ),
251274 self .event_loop ,
252275 )
253276
254- def _issue_streaming_queries (
255- query_samples : list [lg .QuerySample ]) -> None :
277+ def _issue_streaming_queries (query_samples : list [lg .QuerySample ]) -> None :
256278 """Called by the LoadGen to issue queries to the inference endpoint.
257279
258280 Args:
@@ -262,70 +284,106 @@ def _issue_streaming_queries(
262284 `lg.QuerySampleIndex` (i.e., the sample index into the dataset).
263285 """
264286
async def _query_endpoint_async(query_sample: lg.QuerySample) -> None:
    """Run one streaming chat completion and report it to the LoadGen.

    The pre-formulated messages for `query_sample.index` are sent through the
    async OpenAI API client with streaming enabled. The first non-empty text
    chunk is reported through `lg.FirstTokenComplete`; once the stream ends,
    the concatenated text and the completion token count from the usage chunk
    are reported through `lg.QuerySamplesComplete`.

    A first-token completion is always reported before the final completion:
    if the stream ends without any text, or an exception occurs before the
    first chunk arrived, an empty first token is reported instead. Any
    exception is logged and answered with an empty, zero-token response so
    that the LoadGen is never left waiting for this sample.

    Args:
        query_sample: The LoadGen query sample to serve; `index` selects the
            messages and `id` identifies the response.
    """

    def _report(notify, text: str, n_tokens: int) -> None:
        """Pass `text` (UTF-8 encoded) and `n_tokens` to the LoadGen via `notify`.

        `notify` is either `lg.FirstTokenComplete` or
        `lg.QuerySamplesComplete`.
        """
        # `buffer_info()` exposes a raw address that is only valid while the
        # array exists, so keep the array in a local until the call returns.
        payload = array.array("B", text.encode("utf-8"))
        address, length = payload.buffer_info()
        notify(
            [
                lg.QuerySampleResponse(
                    query_sample.id,
                    address,
                    length * payload.itemsize,
                    n_tokens,
                ),
            ],
        )

    # Tracked outside the `try` so the error handler knows whether the first
    # token still has to be reported.
    first_token_reported = False
    try:
        messages = self.loaded_messages[query_sample.index]
        logger.trace(
            "Issuing query sample index: {} with response ID: {}",
            query_sample.index,
            query_sample.id,
        )
        stream = await self.openai_api_client.chat.completions.create(
            stream=True,
            model=self.model_cli.repo_id,
            messages=messages,
            stream_options={"include_usage": True},
        )
        pieces: list[str] = []
        total_tokens = 0
        async for chunk in stream:
            # The usage chunk is the final chunk and will not have 'choices'.
            if chunk.usage is not None:
                total_tokens = int(chunk.usage.completion_tokens)

            # Anything without choices or without text carries no content.
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            text = getattr(choices[0].delta, "content", None)
            if not text:
                continue

            # First non-empty token -> TTFT.
            if not first_token_reported:
                _report(lg.FirstTokenComplete, text, 1)
                first_token_reported = True
            pieces.append(text)

        # The stream produced no text at all: still report a first token,
        # exactly as the error handler below does, so that every completed
        # sample has a matching first-token completion.
        if not first_token_reported:
            _report(lg.FirstTokenComplete, "", 0)
            first_token_reported = True

        # When the stream ends, report the total latency.
        _report(lg.QuerySamplesComplete, "".join(pieces), total_tokens)
    except Exception:  # noqa: BLE001
        logger.exception(
            "Error processing query sample index {} with response ID {}.",
            query_sample.index,
            query_sample.id,
        )
        # Send empty response to LoadGen to avoid hanging.
        # If TTFT was not set, we still need to complete that.
        if not first_token_reported:
            _report(lg.FirstTokenComplete, "", 0)
        _report(lg.QuerySamplesComplete, "", 0)
329387
330388 for query_sample in query_samples :
331389 asyncio .run_coroutine_threadsafe (
@@ -358,9 +416,14 @@ async def _wait_for_pending_queries_async() -> None:
358416 )
359417 future .result ()
360418
361- return lg .ConstructSUT (_issue_streaming_queries
362- if self .scenario is TestScenario .SERVER
363- else _issue_queries , _flush_queries )
419+ return lg .ConstructSUT (
420+ (
421+ _issue_streaming_queries
422+ if self .scenario is TestScenario .SERVER
423+ else _issue_queries
424+ ),
425+ _flush_queries ,
426+ )
364427
365428
366429class ShopifyGlobalCatalogue (Task ):
@@ -392,8 +455,7 @@ def __init__(
392455 )
393456
394457 @staticmethod
395- def formulate_messages (
396- sample : dict [str , Any ]) -> list [ChatCompletionMessageParam ]:
458+ def formulate_messages (sample : dict [str , Any ]) -> list [ChatCompletionMessageParam ]:
397459 """Formulate the messages for chat completion.
398460
399461 Args:
@@ -403,8 +465,8 @@ def formulate_messages(
403465 The messages for chat completion.
404466 """
405467 image_file = BytesIO ()
406- sample ["product_image" ].save (
407- image_file , format = sample ["product_image" ].format )
468+ image_format = sample ["product_image" ].format
469+ sample ["product_image" ].save ( image_file , format = image_format )
408470 image_bytes = image_file .getvalue ()
409471 image_base64 = base64 .b64encode (image_bytes )
410472 image_base64_string = image_base64 .decode ("utf-8" )
@@ -458,8 +520,7 @@ def formulate_messages(
458520 {
459521 "type" : "image_url" ,
460522 "image_url" : {
461- "url" :
462- f"data:image/{ sample ['product_image' ].format } ;base64,"
523+ "url" : f"data:image/{ image_format } ;base64,"
463524 f"{ image_base64_string } " ,
464525 },
465526 },
0 commit comments