
Commit b3e167f

Fixes from reviews and types/style

1 parent 73e711f commit b3e167f

12 files changed: +123 -90 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -129,6 +129,7 @@ dev = [
     "mdformat-gfm~=0.3.6",
 
     # type-checking
+    "pandas-stubs",
     "types-PyYAML~=6.0.1",
     "types-requests~=2.32.0",
     "types-toml",

src/guidellm/backends/openai.py

Lines changed: 10 additions & 11 deletions

@@ -54,7 +54,7 @@ class OpenAIHTTPBackend(Backend):
     def __init__(
         self,
         target: str,
-        model: str | None = None,
+        model: str = "",
         api_routes: dict[str, str] | None = None,
         response_handlers: dict[str, Any] | None = None,
         timeout: float = 60.0,
@@ -192,7 +192,7 @@ async def available_models(self) -> list[str]:
 
         return [item["id"] for item in response.json()["data"]]
 
-    async def default_model(self) -> str | None:
+    async def default_model(self) -> str:
         """
         Get the default model for this backend.
 
@@ -202,9 +202,9 @@ async def default_model(self) -> str | None:
             return self.model
 
         models = await self.available_models()
-        return models[0] if models else None
+        return models[0] if models else ""
 
-    async def resolve(
+    async def resolve(  # type: ignore[override]
         self,
         request: GenerationRequest,
         request_info: RequestInfo,
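
Note: replacing `str | None` with `str` and an empty-string sentinel keeps the signatures uniformly typed; callers test truthiness instead of `is None`. A small sketch of the new `default_model` contract (standalone, with an invented helper name, not the class method itself):

```python
async def pick_default_model(configured: str, served: list[str]) -> str:
    # Mirrors the diff: prefer the configured model, else the first served
    # model, else "" as the "no model" sentinel that used to be None.
    if configured:
        return configured
    return served[0] if served else ""
```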
@@ -230,11 +230,9 @@ async def resolve(
         if history is not None:
             raise NotImplementedError("Multi-turn requests not yet supported")
 
-        response_handler = self._resolve_response_handler(
-            request_type=request.request_type
-        )
         if (request_path := self.api_routes.get(request.request_type)) is None:
             raise ValueError(f"Unsupported request type '{request.request_type}'")
+
         request_url = f"{self.target}/{request_path}"
         request_files = (
             {
@@ -246,6 +244,9 @@ async def resolve(
         )
         request_json = request.arguments.body if not request_files else None
         request_data = request.arguments.body if request_files else None
+        response_handler = self._resolve_response_handler(
+            request_type=request.request_type
+        )
 
         if not request.arguments.stream:
             request_info.timings.request_start = time.time()
@@ -288,10 +289,8 @@ async def resolve(
             request_info.timings.request_iterations += 1
 
             iterations = response_handler.add_streaming_line(chunk)
-            if iterations is None or end_reached:
-                end_reached = True
-                continue
-            if iterations <= 0:
+            if iterations is None or iterations <= 0 or end_reached:
+                end_reached = end_reached or iterations is None
                 continue
 
             if request_info.timings.first_token_iteration is None:
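
Note: the merged guard above keeps a single `continue` path: `None` (end of stream) latches `end_reached`, zero or negative iteration counts are simply skipped, and anything after end-of-stream is ignored. A runnable sketch of the new semantics, faking the `add_streaming_line` results as a list:

```python
from __future__ import annotations


def count_tokens(results: list[int | None]) -> int:
    # Each entry mimics response_handler.add_streaming_line(chunk):
    # None = end-of-stream marker, 0 = line with no deltas, >0 = deltas seen.
    end_reached = False
    total = 0
    for iterations in results:
        if iterations is None or iterations <= 0 or end_reached:
            end_reached = end_reached or iterations is None
            continue
        total += iterations
    return total


# Data after the end-of-stream marker is ignored, matching the old two-branch form.
assert count_tokens([1, 0, 2, None, 3]) == 3
```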

src/guidellm/backends/response_handlers.py

Lines changed: 69 additions & 44 deletions

@@ -9,7 +9,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Protocol
+from typing import Any, Protocol, cast
 
 from guidellm.schemas import GenerationRequest, GenerationResponse, UsageMetrics
 from guidellm.utils import RegistryMixin, json
@@ -109,14 +109,15 @@ def compile_non_streaming(
         :return: Standardized GenerationResponse with extracted text and metrics
         """
         choices, usage = self.extract_choices_and_usage(response)
-        input_metrics, output_metrics = self.extract_metrics(usage)
+        text = choices[0].get("text", "") if choices else ""
+        input_metrics, output_metrics = self.extract_metrics(usage, text)
 
         return GenerationResponse(
             request_id=request.request_id,
             request_args=str(
                 request.arguments.model_dump() if request.arguments else None
             ),
-            text=choices[0].get("text", "") if choices else "",
+            text=text,
             input_metrics=input_metrics,
             output_metrics=output_metrics,
         )
@@ -137,7 +138,7 @@ def add_streaming_line(self, line: str) -> int | None:
         updated = False
         choices, usage = self.extract_choices_and_usage(data)
 
-        if text := choices[0].get("text"):
+        if choices and (text := choices[0].get("text")):
             self.streaming_texts.append(text)
             updated = True
 
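Note: the added `choices and` guard in the hunk above prevents an IndexError on streaming chunks whose `choices` list is empty (for example, usage-only or keep-alive lines); the walrus assignment still only fires when a text delta is present. A runnable sketch of just that guard, with invented chunk shapes:

```python
from __future__ import annotations


def first_text(choices: list[dict]) -> str | None:
    # An empty choices list short-circuits before choices[0] is touched,
    # so usage-only chunks no longer raise IndexError.
    if choices and (text := choices[0].get("text")):
        return text
    return None


assert first_text([]) is None
assert first_text([{"text": "hi"}]) == "hi"
assert first_text([{"finish_reason": "stop"}]) is None
```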
@@ -153,14 +154,15 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse:
         :param request: Original generation request
         :return: Standardized GenerationResponse with concatenated text and metrics
         """
-        input_metrics, output_metrics = self.extract_metrics(self.streaming_usage)
+        text = "".join(self.streaming_texts)
+        input_metrics, output_metrics = self.extract_metrics(self.streaming_usage, text)
 
         return GenerationResponse(
             request_id=request.request_id,
             request_args=str(
                 request.arguments.model_dump() if request.arguments else None
             ),
-            text="".join(self.streaming_texts),
+            text=text,
             input_metrics=input_metrics,
             output_metrics=output_metrics,
         )
@@ -194,25 +196,34 @@ def extract_choices_and_usage(
         return response.get("choices", []), response.get("usage", {})
 
     def extract_metrics(
-        self, usage: dict[str, int | dict[str, int]] | None
+        self, usage: dict[str, int | dict[str, int]] | None, text: str
     ) -> tuple[UsageMetrics, UsageMetrics]:
         """
         Extract input and output usage metrics from API response usage data.
 
         :param usage: Usage data dictionary from API response
+        :param text: Generated text for calculating word and character counts
         :return: Tuple of input_metrics and output_metrics as UsageMetrics objects
         """
         if not usage:
-            return UsageMetrics(), UsageMetrics()
+            return UsageMetrics(), UsageMetrics(
+                text_words=len(text.split()) if text else 0,
+                text_characters=len(text) if text else 0,
+            )
 
-        input_details: dict[str, int] = usage.get("prompt_tokens_details", {}) or {}
-        output_details: dict[str, int] = (
-            usage.get("completion_tokens_details", {}) or {}
+        input_details: dict[str, int] = cast(
+            "dict[str, int]", usage.get("prompt_tokens_details", {}) or {}
+        )
+        output_details: dict[str, int] = cast(
+            "dict[str, int]", usage.get("completion_tokens_details", {}) or {}
         )
+        usage_metrics: dict[str, int] = cast("dict[str, int]", usage)
 
         return UsageMetrics(
             text_tokens=(
-                input_details.get("prompt_tokens") or usage.get("prompt_tokens")
+                input_details.get("prompt_tokens")
+                or usage_metrics.get("prompt_tokens")
+                or 0
             ),
             image_tokens=input_details.get("image_tokens"),
             video_tokens=input_details.get("video_tokens"),
@@ -221,8 +232,11 @@ def extract_metrics(
         ), UsageMetrics(
             text_tokens=(
                 output_details.get("completion_tokens")
-                or usage.get("completion_tokens")
+                or usage_metrics.get("completion_tokens")
+                or 0
             ),
+            text_words=len(text.split()) if text else 0,
+            text_characters=len(text) if text else 0,
             image_tokens=output_details.get("image_tokens"),
             video_tokens=output_details.get("video_tokens"),
             audio_tokens=output_details.get("audio_tokens"),
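
Note: the `cast(...)` calls introduced above only narrow the declared `dict[str, int | dict[str, int]]` union for the type checker; at runtime `cast` returns its argument unchanged, so the `or {}` fallbacks still do the real work. An isolated sketch of the pattern with an invented usage payload:

```python
from typing import cast

usage: dict[str, int | dict[str, int]] = {
    "prompt_tokens": 12,
    "prompt_tokens_details": {"prompt_tokens": 12},
}

# cast() is a no-op at runtime; it just lets mypy treat the values as
# dict[str, int] so .get() results type-check as int | None.
details = cast("dict[str, int]", usage.get("prompt_tokens_details", {}) or {})
flat = cast("dict[str, int]", usage)

tokens = details.get("prompt_tokens") or flat.get("prompt_tokens") or 0
assert tokens == 12
```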
@@ -254,14 +268,16 @@ def compile_non_streaming(
         :return: Standardized GenerationResponse with extracted content and metrics
         """
         choices, usage = self.extract_choices_and_usage(response)
-        input_metrics, output_metrics = self.extract_metrics(usage)
+        choice = choices[0] if choices else {}
+        text = choice.get("content", "")
+        input_metrics, output_metrics = self.extract_metrics(usage, text)
 
         return GenerationResponse(
             request_id=request.request_id,
             request_args=str(
                 request.arguments.model_dump() if request.arguments else None
             ),
-            text=(choices[0].get("message", {}).get("content", "") if choices else ""),
+            text=text,
             input_metrics=input_metrics,
             output_metrics=output_metrics,
         )
@@ -298,14 +314,15 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse:
         :param request: Original generation request
         :return: Standardized GenerationResponse with concatenated content and metrics
         """
-        input_metrics, output_metrics = self.extract_metrics(self.streaming_usage)
+        text = "".join(self.streaming_texts)
+        input_metrics, output_metrics = self.extract_metrics(self.streaming_usage, text)
 
         return GenerationResponse(
             request_id=request.request_id,
             request_args=str(
                 request.arguments.model_dump() if request.arguments else None
             ),
-            text="".join(self.streaming_texts),
+            text=text,
             input_metrics=input_metrics,
             output_metrics=output_metrics,
         )
@@ -352,29 +369,18 @@ def compile_non_streaming(
         :param response: Complete API response containing text and usage data
         :return: Standardized GenerationResponse with extracted text and metrics
         """
-        usage: dict[str, int | dict[str, int]] = response.get("usage", {})
-        input_details: dict[str, int] = usage.get("input_token_details", {}) or {}
-        output_details: dict[str, int] = usage.get("output_token_details", {}) or {}
         text: str = response.get("text", "")
+        usage: dict[str, int | dict[str, int]] = response.get("usage", {})
+        input_metrics, output_metrics = self.extract_metrics(usage, text)
 
         return GenerationResponse(
             request_id=request.request_id,
             request_args=str(
                 request.arguments.model_dump() if request.arguments else None
             ),
             text=text,
-            input_metrics=UsageMetrics(
-                text_tokens=input_details.get("text_tokens", usage.get("input_tokens")),
-                audio_tokens=input_details.get(
-                    "audio_tokens", usage.get("input_tokens")
-                ),
-                audio_seconds=input_details.get("seconds", usage.get("seconds")),
-            ),
-            output_metrics=UsageMetrics(
-                text_tokens=output_details.get(
-                    "text_tokens", usage.get("output_tokens")
-                ),
-            ),
+            input_metrics=input_metrics,
+            output_metrics=output_metrics,
         )
 
     def add_streaming_line(self, line: str) -> int | None:
@@ -394,8 +400,6 @@ def add_streaming_line(self, line: str) -> int | None:
             return 0
 
         data: dict[str, Any] = json.loads(line)
-        text: str
-        usage: dict[str, int | dict[str, int]]
         updated = False
 
         if text := data.get("text"):
@@ -414,20 +418,21 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse:
         :param request: Original generation request
         :return: Standardized GenerationResponse with concatenated text and metrics
         """
-        input_metrics, output_metrics = self.extract_metrics(self.streaming_usage)
+        text = "".join(self.streaming_texts)
+        input_metrics, output_metrics = self.extract_metrics(self.streaming_usage, text)
 
         return GenerationResponse(
             request_id=request.request_id,
             request_args=str(
                 request.arguments.model_dump() if request.arguments else None
             ),
-            text="".join(self.streaming_texts),
+            text=text,
             input_metrics=input_metrics,
             output_metrics=output_metrics,
         )
 
     def extract_metrics(
-        self, usage: dict[str, int | dict[str, int]] | None
+        self, usage: dict[str, int | dict[str, int]] | None, text: str
     ) -> tuple[UsageMetrics, UsageMetrics]:
         """
         Extract input and output usage metrics from audio API response usage data.
@@ -436,20 +441,40 @@ def extract_metrics(
         in addition to standard text token counts.
 
         :param usage: Usage data dictionary from audio API response
+        :param text: Generated text for calculating word and character counts
        :return: Tuple of input_metrics and output_metrics as UsageMetrics objects
         """
         if not usage:
-            return UsageMetrics(), UsageMetrics()
+            return UsageMetrics(), UsageMetrics(
+                text_words=len(text.split()) if text else 0,
+                text_characters=len(text) if text else 0,
+            )
 
-        input_details: dict[str, int] = usage.get("input_token_details", {}) or {}
-        output_details: dict[str, int] = usage.get("output_token_details", {}) or {}
+        input_details: dict[str, int] = cast(
+            "dict[str, int]", usage.get("input_token_details", {}) or {}
+        )
+        output_details: dict[str, int] = cast(
+            "dict[str, int]", usage.get("output_token_details", {}) or {}
+        )
+        usage_metrics: dict[str, int] = cast("dict[str, int]", usage)
 
         return UsageMetrics(
-            text_tokens=(input_details.get("text_tokens") or usage.get("input_tokens")),
+            text_tokens=input_details.get("text_tokens") or 0,
             audio_tokens=(
-                input_details.get("audio_tokens") or usage.get("audio_tokens")
+                input_details.get("audio_tokens")
+                or usage_metrics.get("audio_tokens")
+                or usage_metrics.get("input_tokens")
+                or 0
+            ),
+            audio_seconds=(
+                input_details.get("seconds") or usage_metrics.get("seconds") or 0
             ),
-            audio_seconds=(input_details.get("seconds") or usage.get("seconds")),
         ), UsageMetrics(
-            text_tokens=output_details.get("text_tokens") or usage.get("output_tokens"),
+            text_tokens=(
+                output_details.get("text_tokens")
+                or usage_metrics.get("output_tokens")
+                or 0
+            ),
+            text_words=len(text.split()) if text else 0,
+            text_characters=len(text) if text else 0,
         )
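
Note: the rewritten fallbacks are `or`-chains ending in `or 0`, which coalesce missing keys, `None`, and explicit zeros alike. A standalone sketch of the audio input-token fallback order, with invented payloads:

```python
def audio_input_tokens(usage: dict) -> int:
    # Fallback order mirroring the diff: detailed count first, then the
    # flat audio_tokens field, then input_tokens, finally 0. Because `or`
    # treats 0 as falsy, an explicit 0 also falls through to the next source.
    details = usage.get("input_token_details", {}) or {}
    return (
        details.get("audio_tokens")
        or usage.get("audio_tokens")
        or usage.get("input_tokens")
        or 0
    )


assert audio_input_tokens({"input_token_details": {"audio_tokens": 7}}) == 7
assert audio_input_tokens({"input_tokens": 42}) == 42
assert audio_input_tokens({}) == 0
```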

src/guidellm/benchmark/schemas/generative/accumulator.py

Lines changed: 5 additions & 5 deletions

@@ -174,7 +174,7 @@ def update_estimate(
             config.warmup >= 1.0
             and scheduler_state.remaining_duration is not None
             and self.duration is not None
-            and config.warmup >= self.duration
+            and self.duration >= config.warmup
         )
         exceeded_count = (
             config.warmup >= 1.0
@@ -184,7 +184,7 @@ def update_estimate(
         exceeded_fraction = (
             config.warmup < 1.0
             and scheduler_state.remaining_fraction is not None
-            and config.warmup >= 1.0 - scheduler_state.remaining_fraction
+            and 1.0 - scheduler_state.remaining_fraction >= config.warmup
         )
 
         if exceeded_time or exceeded_count or exceeded_fraction:
@@ -198,7 +198,7 @@ def update_estimate(
         exceeded_time = (
             config.cooldown >= 1.0
             and scheduler_state.remaining_duration is not None
-            and config.cooldown <= scheduler_state.remaining_duration
+            and scheduler_state.remaining_duration <= config.cooldown
         )
         exceeded_count = (
             config.cooldown >= 1.0
@@ -208,7 +208,7 @@ def update_estimate(
         exceeded_fraction = (
             config.cooldown < 1.0
             and scheduler_state.remaining_fraction is not None
-            and config.cooldown >= scheduler_state.remaining_fraction
+            and scheduler_state.remaining_fraction <= config.cooldown
         )
 
         if exceeded_time or exceeded_count or exceeded_fraction:
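
Note: not every comparison flip in this file is cosmetic. The cooldown-fraction swap (`config.cooldown >= remaining_fraction` to `remaining_fraction <= config.cooldown`) is logically equivalent, but the warmup-fraction change reverses the predicate: it now reads "the completed fraction has reached the warmup threshold" rather than the other way around. A quick numeric check with invented values:

```python
warmup = 0.2              # warmup covers the first 20% of the run
remaining_fraction = 0.9  # only 10% of the run has completed
completed = 1.0 - remaining_fraction

# Old form: config.warmup >= completed -> True while still inside warmup.
# New form: completed >= config.warmup -> True only once warmup is over.
assert (warmup >= completed) is True
assert (completed >= warmup) is False
```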
@@ -401,7 +401,7 @@ def update_estimate(
             + scheduler_state.cancelled_requests
         )
 
-        # All requests much have queued, dequeued, resolve_end, and finalized timings
+        # All requests must have queued, dequeued, resolve_end, and finalized timings
         timings: RequestTimings = stats.info.timings
         if any(
             timing is None

src/guidellm/benchmark/schemas/generative/metrics.py

Lines changed: 1 addition & 1 deletion

@@ -882,7 +882,7 @@ def compile(cls, accumulator: GenerativeBenchmarkAccumulator) -> GenerativeMetrics:
                 errored=errored,
             ),
             prompt_tokens_per_second=StatusDistributionSummary.rate_distribution_from_timings_function(
-                function=lambda req: req.prompt_tokens_timings,
+                function=lambda req: req.prompt_tokens_timing,
                 successful=successful,
                 incomplete=incomplete,
                 errored=errored,
