@@ -110,7 +110,6 @@ class OpenAIClient(ModelClient):
         api_key (Optional[str], optional): OpenAI API key. Defaults to None.
         chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
         input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
-        model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.

     Note:
         We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API.
@@ -142,15 +141,13 @@ def __init__(
         api_key: Optional[str] = None,
         chat_completion_parser: Callable[[Completion], Any] = None,
         input_type: Literal["text", "messages"] = "text",
-        model_type: ModelType = ModelType.LLM,
     ):
         r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.

         Args:
             api_key (Optional[str], optional): OpenAI API key. Defaults to None.
             chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
             input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
-            model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
         """
         super().__init__()
         self._api_key = api_key
@@ -160,7 +157,6 @@ def __init__(
             chat_completion_parser or get_first_message_content
         )
         self._input_type = input_type
-        self.model_type = model_type

     def init_sync_client(self):
         api_key = self._api_key or os.getenv("OPENAI_API_KEY")
@@ -235,6 +231,7 @@ def convert_inputs_to_api_kwargs(
         self,
         input: Optional[Any] = None,
         model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,  # Now required in practice
     ) -> Dict:
         r"""
         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -259,20 +256,23 @@ def convert_inputs_to_api_kwargs(
                 - mask: Path to the mask image
                 For variations (DALL-E 2 only):
                 - image: Path to the input image
+            model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Required.

         Returns:
             Dict: API-specific kwargs for the model call
         """
+        if model_type == ModelType.UNDEFINED:
+            raise ValueError("model_type must be specified")

         final_model_kwargs = model_kwargs.copy()
-        if self.model_type == ModelType.EMBEDDER:
+        if model_type == ModelType.EMBEDDER:
             if isinstance(input, str):
                 input = [input]
             # convert input to input
             if not isinstance(input, Sequence):
                 raise TypeError("input must be a sequence of text")
             final_model_kwargs["input"] = input
-        elif self.model_type == ModelType.LLM:
+        elif model_type == ModelType.LLM:
             # convert input to messages
             messages: List[Dict[str, str]] = []
             images = final_model_kwargs.pop("images", None)
@@ -317,7 +317,7 @@ def convert_inputs_to_api_kwargs(
             else:
                 messages.append({"role": "system", "content": input})
             final_model_kwargs["messages"] = messages
-        elif self.model_type == ModelType.IMAGE_GENERATION:
+        elif model_type == ModelType.IMAGE_GENERATION:
             # For image generation, input is the prompt
             final_model_kwargs["prompt"] = input
             # Ensure model is specified
@@ -362,7 +362,7 @@ def convert_inputs_to_api_kwargs(
             else:
                 raise ValueError(f"Invalid operation: {operation}")
         else:
-            raise ValueError(f"model_type {self.model_type} is not supported")
+            raise ValueError(f"model_type {model_type} is not supported")
         return final_model_kwargs

     def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
@@ -379,11 +379,7 @@ def parse_image_generation_response(self, response: List[Image]) -> GeneratorOut
             )
         except Exception as e:
             log.error(f"Error parsing image generation response: {e}")
-            return GeneratorOutput(
-                data=None,
-                error=str(e),
-                raw_response=str(response)
-            )
+            return GeneratorOutput(data=None, error=str(e), raw_response=str(response))

     @backoff.on_exception(
         backoff.expo,
@@ -400,6 +396,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
         """
         kwargs is the combined input and model_kwargs. Support streaming call.
         """
+        if model_type == ModelType.UNDEFINED:
+            raise ValueError("model_type must be specified")
+
         log.info(f"api_kwargs: {api_kwargs}")
         if model_type == ModelType.EMBEDDER:
             return self.sync_client.embeddings.create(**api_kwargs)
@@ -449,6 +448,9 @@ async def acall(
         """
         kwargs is the combined input and model_kwargs
         """
+        if model_type == ModelType.UNDEFINED:
+            raise ValueError("model_type must be specified")
+
         if self.async_client is None:
             self.async_client = self.init_async_client()
         if model_type == ModelType.EMBEDDER:
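
Taken together, these hunks remove the model_type attribute from the constructor and make model_type a required per-call argument to convert_inputs_to_api_kwargs, call, and acall; leaving it at the ModelType.UNDEFINED default now raises a ValueError. A minimal usage sketch of the new calling convention follows; the import paths, API key, and model name are illustrative placeholders, not part of this diff:

    # Sketch only: import paths are assumed, adjust to wherever OpenAIClient
    # and ModelType live in this package.
    from your_package.model_client import OpenAIClient  # hypothetical path
    from your_package.core.types import ModelType       # hypothetical path

    client = OpenAIClient(api_key="sk-...")  # model_type is no longer set here

    # model_type is now supplied on every call:
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="What is the capital of France?",
        model_kwargs={"model": "gpt-4o-mini"},  # model name is illustrative
        model_type=ModelType.LLM,
    )
    response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)

    # Omitting model_type leaves it at ModelType.UNDEFINED, which raises
    # ValueError("model_type must be specified").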