
Commit 1dc0896

Add the ability to pass in extra body params for the OpenAI pathway, similar to the extra query params that were added
1 parent e6f3dfc commit 1dc0896

File tree: 2 files changed (+29 −3 lines changed)

src/guidellm/backend/openai.py
src/guidellm/dataset/synthetic.py
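For context, a minimal usage sketch of the new parameter (not part of the diff): assuming the backend class defined in openai.py is OpenAIHTTPBackend and that the CHAT_COMPLETIONS / TEXT_COMPLETIONS endpoint-type constants are the strings "chat_completions" and "text_completions", extra_body can presumably be passed either as a flat dict applied to every request or keyed per endpoint type, mirroring how extra_query already works. The field names below (repetition_penalty, top_k) are purely illustrative server-side extensions.

# Flat form: the same extra fields are merged into every completions payload.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",
    extra_body={"repetition_penalty": 1.1},
)

# Per-endpoint form: fields apply only to the matching endpoint type.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000",
    extra_body={
        "chat_completions": {"repetition_penalty": 1.1},
        "text_completions": {"top_k": 40},
    },
)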

src/guidellm/backend/openai.py

Lines changed: 26 additions & 2 deletions
@@ -84,6 +84,7 @@ def __init__(
         follow_redirects: Optional[bool] = None,
         max_output_tokens: Optional[int] = None,
         extra_query: Optional[dict] = None,
+        extra_body: Optional[dict] = None,
     ):
         super().__init__(type_="openai_http")
         self._target = target or settings.openai.base_url
@@ -120,6 +121,7 @@ def __init__(
             else settings.openai.max_output_tokens
         )
         self.extra_query = extra_query
+        self.extra_body = extra_body
         self._async_client: Optional[httpx.AsyncClient] = None

     @property
@@ -242,7 +244,9 @@ async def text_completions(  # type: ignore[override]

         headers = self._headers()
         params = self._params(TEXT_COMPLETIONS)
+        body = self._body(TEXT_COMPLETIONS)
         payload = self._completions_payload(
+            body=body,
             orig_kwargs=kwargs,
             max_output_tokens=output_token_count,
             prompt=prompt,
@@ -317,10 +321,12 @@ async def chat_completions(  # type: ignore[override]
         logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
         headers = self._headers()
         params = self._params(CHAT_COMPLETIONS)
+        body = self._body(CHAT_COMPLETIONS)
         messages = (
             content if raw_content else self._create_chat_messages(content=content)
         )
         payload = self._completions_payload(
+            body=body,
             orig_kwargs=kwargs,
             max_output_tokens=output_token_count,
             messages=messages,
@@ -396,10 +402,28 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:

         return self.extra_query

+    def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
+        if self.extra_body is None:
+            return {}
+
+        if (
+            CHAT_COMPLETIONS in self.extra_body
+            or MODELS in self.extra_body
+            or TEXT_COMPLETIONS in self.extra_body
+        ):
+            return self.extra_body.get(endpoint_type, {})
+
+        return self.extra_body
+
     def _completions_payload(
-        self, orig_kwargs: Optional[dict], max_output_tokens: Optional[int], **kwargs
+        self,
+        body: Optional[dict],
+        orig_kwargs: Optional[dict],
+        max_output_tokens: Optional[int],
+        **kwargs,
     ) -> dict:
-        payload = orig_kwargs or {}
+        payload = body or {}
+        payload.update(orig_kwargs or {})
         payload.update(kwargs)
         payload["model"] = self.model
         payload["stream"] = True
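To make the merge order in _completions_payload concrete, here is a standalone sketch based only on what the diff shows: the resolved extra body is the base, caller-supplied kwargs override it, the explicit prompt/messages and token-limit arguments override those, and "model" and "stream" are always set last by the backend.

def build_payload(body, orig_kwargs, model, **kwargs):
    # Mirrors _completions_payload above: lowest precedence first.
    payload = dict(body or {})          # resolved extra_body
    payload.update(orig_kwargs or {})   # caller-supplied request kwargs
    payload.update(kwargs)              # prompt/messages, token limits, etc.
    payload["model"] = model            # always forced by the backend
    payload["stream"] = True
    return payload

# An extra_body default survives unless the request overrides it:
build_payload({"temperature": 0.2}, {"temperature": 0.7}, "my-model")
# -> {"temperature": 0.7, "model": "my-model", "stream": True}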

src/guidellm/dataset/synthetic.py

Lines changed: 3 additions & 1 deletion
@@ -201,7 +201,9 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
 class SyntheticDatasetCreator(DatasetCreator):
     @classmethod
     def is_supported(
-        cls, data: Any, data_args: Optional[dict[str, Any]]  # noqa: ARG003
+        cls,
+        data: Any,
+        data_args: Optional[dict[str, Any]],  # noqa: ARG003
     ) -> bool:
         if (
             isinstance(data, Path)
