@@ -84,6 +84,7 @@ def __init__(
84
84
follow_redirects : Optional [bool ] = None ,
85
85
max_output_tokens : Optional [int ] = None ,
86
86
extra_query : Optional [dict ] = None ,
87
+ extra_body : Optional [dict ] = None ,
87
88
):
88
89
super ().__init__ (type_ = "openai_http" )
89
90
self ._target = target or settings .openai .base_url
@@ -120,6 +121,7 @@ def __init__(
120
121
else settings .openai .max_output_tokens
121
122
)
122
123
self .extra_query = extra_query
124
+ self .extra_body = extra_body
123
125
self ._async_client : Optional [httpx .AsyncClient ] = None
124
126
125
127
@property
@@ -242,7 +244,9 @@ async def text_completions( # type: ignore[override]
242
244
243
245
headers = self ._headers ()
244
246
params = self ._params (TEXT_COMPLETIONS )
247
+ body = self ._body (TEXT_COMPLETIONS )
245
248
payload = self ._completions_payload (
249
+ body = body ,
246
250
orig_kwargs = kwargs ,
247
251
max_output_tokens = output_token_count ,
248
252
prompt = prompt ,
@@ -317,10 +321,12 @@ async def chat_completions( # type: ignore[override]
317
321
logger .debug ("{} invocation with args: {}" , self .__class__ .__name__ , locals ())
318
322
headers = self ._headers ()
319
323
params = self ._params (CHAT_COMPLETIONS )
324
+ body = self ._body (CHAT_COMPLETIONS )
320
325
messages = (
321
326
content if raw_content else self ._create_chat_messages (content = content )
322
327
)
323
328
payload = self ._completions_payload (
329
+ body = body ,
324
330
orig_kwargs = kwargs ,
325
331
max_output_tokens = output_token_count ,
326
332
messages = messages ,
@@ -396,10 +402,28 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
396
402
397
403
return self .extra_query
398
404
405
def _body(self, endpoint_type: EndpointType) -> dict:
    """Return the extra request-body fields to merge into a payload.

    ``self.extra_body`` is interpreted in one of two ways:

    * If it contains any of the endpoint-type keys (``CHAT_COMPLETIONS``,
      ``MODELS``, ``TEXT_COMPLETIONS``), it is treated as a mapping from
      endpoint type to that endpoint's extra fields, and only the entry
      for ``endpoint_type`` is returned.
    * Otherwise it is treated as a flat mapping of fields that applies to
      every endpoint.

    :param endpoint_type: The endpoint whose extra body fields are wanted.
    :return: A new dict holding the extra fields; empty when no extra body
        is configured or the endpoint has no entry.
    """
    if self.extra_body is None:
        return {}

    endpoint_keys = (CHAT_COMPLETIONS, MODELS, TEXT_COMPLETIONS)
    if any(key in self.extra_body for key in endpoint_keys):
        # ``or {}`` also covers an entry that is explicitly set to None.
        selected = self.extra_body.get(endpoint_type) or {}
    else:
        selected = self.extra_body

    # Return a shallow copy: the caller updates the returned dict in place
    # while building the request payload, which would otherwise write
    # per-request fields back into the stored ``extra_body`` and leak them
    # into subsequent requests.
    return dict(selected)
417
+
399
418
def _completions_payload (
400
- self , orig_kwargs : Optional [dict ], max_output_tokens : Optional [int ], ** kwargs
419
+ self ,
420
+ body : Optional [dict ],
421
+ orig_kwargs : Optional [dict ],
422
+ max_output_tokens : Optional [int ],
423
+ ** kwargs ,
401
424
) -> dict :
402
- payload = orig_kwargs or {}
425
+ payload = body or {}
426
+ payload .update (orig_kwargs or {})
403
427
payload .update (kwargs )
404
428
payload ["model" ] = self .model
405
429
payload ["stream" ] = True
0 commit comments