Skip to content

Commit 922681a

Browse files
committed
Fixed tests and made adjustments
1 parent d121c18 commit 922681a

File tree

6 files changed

+167
-43
lines changed

6 files changed

+167
-43
lines changed

adalflow/adalflow/components/model_client/openai_client.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ class OpenAIClient(ModelClient):
110110
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
111111
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
112112
input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
113-
model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
114113
115114
Note:
116115
We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API.
@@ -142,15 +141,13 @@ def __init__(
142141
api_key: Optional[str] = None,
143142
chat_completion_parser: Callable[[Completion], Any] = None,
144143
input_type: Literal["text", "messages"] = "text",
145-
model_type: ModelType = ModelType.LLM,
146144
):
147145
r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.
148146
149147
Args:
150148
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
151149
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
152150
input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
153-
model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
154151
"""
155152
super().__init__()
156153
self._api_key = api_key
@@ -160,7 +157,6 @@ def __init__(
160157
chat_completion_parser or get_first_message_content
161158
)
162159
self._input_type = input_type
163-
self.model_type = model_type
164160

165161
def init_sync_client(self):
166162
api_key = self._api_key or os.getenv("OPENAI_API_KEY")
@@ -235,6 +231,7 @@ def convert_inputs_to_api_kwargs(
235231
self,
236232
input: Optional[Any] = None,
237233
model_kwargs: Dict = {},
234+
model_type: ModelType = ModelType.UNDEFINED, # Now required in practice
238235
) -> Dict:
239236
r"""
240237
Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -259,20 +256,23 @@ def convert_inputs_to_api_kwargs(
259256
- mask: Path to the mask image
260257
For variations (DALL-E 2 only):
261258
- image: Path to the input image
259+
model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Required.
262260
263261
Returns:
264262
Dict: API-specific kwargs for the model call
265263
"""
264+
if model_type == ModelType.UNDEFINED:
265+
raise ValueError("model_type must be specified")
266266

267267
final_model_kwargs = model_kwargs.copy()
268-
if self.model_type == ModelType.EMBEDDER:
268+
if model_type == ModelType.EMBEDDER:
269269
if isinstance(input, str):
270270
input = [input]
271271
# convert input to input
272272
if not isinstance(input, Sequence):
273273
raise TypeError("input must be a sequence of text")
274274
final_model_kwargs["input"] = input
275-
elif self.model_type == ModelType.LLM:
275+
elif model_type == ModelType.LLM:
276276
# convert input to messages
277277
messages: List[Dict[str, str]] = []
278278
images = final_model_kwargs.pop("images", None)
@@ -317,7 +317,7 @@ def convert_inputs_to_api_kwargs(
317317
else:
318318
messages.append({"role": "system", "content": input})
319319
final_model_kwargs["messages"] = messages
320-
elif self.model_type == ModelType.IMAGE_GENERATION:
320+
elif model_type == ModelType.IMAGE_GENERATION:
321321
# For image generation, input is the prompt
322322
final_model_kwargs["prompt"] = input
323323
# Ensure model is specified
@@ -362,7 +362,7 @@ def convert_inputs_to_api_kwargs(
362362
else:
363363
raise ValueError(f"Invalid operation: {operation}")
364364
else:
365-
raise ValueError(f"model_type {self.model_type} is not supported")
365+
raise ValueError(f"model_type {model_type} is not supported")
366366
return final_model_kwargs
367367

368368
def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
@@ -379,11 +379,7 @@ def parse_image_generation_response(self, response: List[Image]) -> GeneratorOut
379379
)
380380
except Exception as e:
381381
log.error(f"Error parsing image generation response: {e}")
382-
return GeneratorOutput(
383-
data=None,
384-
error=str(e),
385-
raw_response=str(response)
386-
)
382+
return GeneratorOutput(data=None, error=str(e), raw_response=str(response))
387383

388384
@backoff.on_exception(
389385
backoff.expo,
@@ -400,6 +396,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
400396
"""
401397
kwargs is the combined input and model_kwargs. Support streaming call.
402398
"""
399+
if model_type == ModelType.UNDEFINED:
400+
raise ValueError("model_type must be specified")
401+
403402
log.info(f"api_kwargs: {api_kwargs}")
404403
if model_type == ModelType.EMBEDDER:
405404
return self.sync_client.embeddings.create(**api_kwargs)
@@ -449,6 +448,9 @@ async def acall(
449448
"""
450449
kwargs is the combined input and model_kwargs
451450
"""
451+
if model_type == ModelType.UNDEFINED:
452+
raise ValueError("model_type must be specified")
453+
452454
if self.async_client is None:
453455
self.async_client = self.init_async_client()
454456
if model_type == ModelType.EMBEDDER:

adalflow/adalflow/core/generator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
7373
name (Optional[str], optional): The name of the generator. Defaults to None.
7474
cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
7575
use_cache (bool, optional): Whether to use cache. Defaults to False.
76+
model_type (ModelType, optional): The type of the model. Defaults to ModelType.LLM.
7677
"""
7778

7879
def __init__(
@@ -90,6 +91,7 @@ def __init__(
9091
# args for the cache
9192
cache_path: Optional[str] = None,
9293
use_cache: bool = False,
94+
model_type: ModelType = ModelType.LLM, # Add model_type parameter with default
9395
) -> None:
9496
r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
9597
- task_desc_str
@@ -122,7 +124,7 @@ def __init__(
122124
CallbackManager.__init__(self)
123125

124126
self.name = name or self.__class__.__name__
125-
self.model_type = model_client.model_type # Get model type from client
127+
self.model_type = model_type # Use the passed model_type instead of getting from client
126128

127129
self._init_prompt(template, prompt_kwargs)
128130

@@ -326,6 +328,7 @@ def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]:
326328
api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
327329
input=prompt_str,
328330
model_kwargs=composed_model_kwargs,
331+
model_type=self.model_type,
329332
)
330333
return api_kwargs
331334

adalflow/tests/test_generator.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from adalflow.core.model_client import ModelClient
1616
from adalflow.components.model_client.groq_client import GroqAPIClient
1717
from adalflow.tracing import GeneratorStateLogger
18+
from adalflow.core.types import ModelType
1819

1920

2021
class TestGenerator(IsolatedAsyncioTestCase):
@@ -32,7 +33,7 @@ def setUp(self):
3233
)
3334
self.mock_api_client = mock_api_client
3435

35-
self.generator = Generator(model_client=mock_api_client)
36+
self.generator = Generator(model_client=mock_api_client, model_type=ModelType.LLM)
3637
self.save_dir = "./tests/log"
3738
self.project_name = "TestGenerator"
3839
self.filename = "prompt_logger_test.json"
@@ -182,7 +183,7 @@ def test_groq_client_call(self, mock_call):
182183
template = "Hello, {{ input_str }}!"
183184

184185
# Initialize the Generator with the mocked client
185-
generator = Generator(model_client=self.client, template=template)
186+
generator = Generator(model_client=self.client, template=template, model_type=ModelType.LLM)
186187

187188
# Call the generator and get the output
188189
output = generator.call(prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs)
+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
Multimodal Client Tutorial
2+
==========================
3+
4+
This tutorial demonstrates how to use the OpenAI client for different types of tasks: text generation, vision analysis, and image generation.
5+
6+
Model Types
7+
-----------
8+
9+
The OpenAI client supports three types of operations:
10+
11+
1. Text/Chat Completion (``ModelType.LLM``)
12+
- Standard text generation
13+
- Vision analysis (with GPT-4V)
14+
2. Image Generation (``ModelType.IMAGE_GENERATION``)
15+
- DALL-E image generation
16+
3. Embeddings (``ModelType.EMBEDDER``)
17+
- Text embeddings
18+
19+
Basic Usage
20+
-----------
21+
22+
The model type is specified when creating a ``Generator`` instance:
23+
24+
.. code-block:: python
25+
26+
from adalflow.core import Generator
27+
from adalflow.components.model_client.openai_client import OpenAIClient
28+
from adalflow.core.types import ModelType
29+
30+
# Create the client
31+
client = OpenAIClient()
32+
33+
# For text generation
34+
gen = Generator(
35+
model_client=client,
36+
model_kwargs={"model": "gpt-4", "max_tokens": 100},
37+
model_type=ModelType.LLM # Specify LLM type
38+
)
39+
response = gen({"input_str": "Hello, world!"})
40+
41+
Vision Tasks
42+
------------
43+
44+
Vision tasks use ``ModelType.LLM`` since they are handled by GPT-4V:
45+
46+
.. code-block:: python
47+
48+
# Vision analysis
49+
vision_gen = Generator(
50+
model_client=client,
51+
model_kwargs={
52+
"model": "gpt-4o-mini",
53+
"images": "path/to/image.jpg",
54+
"max_tokens": 300,
55+
},
56+
model_type=ModelType.LLM # Vision uses LLM type
57+
)
58+
response = vision_gen({"input_str": "What do you see in this image?"})
59+
60+
Image Generation
61+
----------------
62+
63+
For DALL-E image generation, use ``ModelType.IMAGE_GENERATION``:
64+
65+
.. code-block:: python
66+
67+
# Image generation with DALL-E
68+
dalle_gen = Generator(
69+
model_client=client,
70+
model_kwargs={
71+
"model": "dall-e-3",
72+
"size": "1024x1024",
73+
"quality": "standard",
74+
"n": 1,
75+
},
76+
model_type=ModelType.IMAGE_GENERATION # Specify image generation type
77+
)
78+
response = dalle_gen({"input_str": "A cat playing with yarn"})
79+
80+
Backward Compatibility
81+
----------------------
82+
83+
For backward compatibility with existing code:
84+
85+
1. ``model_type`` defaults to ``ModelType.LLM`` if not specified
86+
2. Older models that only support text continue to work with ``ModelType.LLM``
87+
3. The OpenAI client handles the appropriate API endpoints based on the model type
88+
89+
Error Handling
90+
--------------
91+
92+
The client includes error handling for:
93+
94+
1. Invalid model types for operations
95+
2. Invalid image URLs or file paths
96+
3. Unsupported model capabilities
97+
4. API errors and rate limits
98+
99+
Complete Example
100+
----------------
101+
102+
See the complete example in ``tutorials/multimodal_client_testing_examples.py``, which demonstrates:
103+
104+
1. Basic text generation
105+
2. Vision analysis with image input
106+
3. DALL-E image generation
107+
4. Error handling for invalid inputs

tests/test_generator.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)