import litellm
from litellm import token_counter
from collections import defaultdict
-from typing import Any, Dict, List, Type, Optional, Tuple
+from typing import Any, Dict, List, Type, Optional, Tuple, DefaultDict
from pydantic import BaseModel, Field
from pydantic_core import to_jsonable_python
from litellm.types.utils import Usage
from litellm.utils import trim_messages
from tenacity import Retrying, AsyncRetrying
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
from instructor.client import Instructor, AsyncInstructor
-from adala.utils.parse import MessagesBuilder, MessageChunkType
+from adala.utils.parse import MessageChunkType
+from adala.utils.message_builder import MessagesBuilder
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.types import debug_time_it
from litellm.exceptions import BadRequestError

logger = logging.getLogger(__name__)

-def _get_usage_dict(usage: Usage, model: str) -> Dict:
+def _count_message_content(
+    message: Dict[str, Any], counts: DefaultDict[str, int]
+) -> None:
+    """Helper method to count different content types in a message."""
+    if "role" in message and "content" in message:
+        content = message["content"]
+        if isinstance(content, str):
+            counts["text"] += 1
+        elif isinstance(content, list):
+            for content_part in content:
+                if isinstance(content_part, dict) and "type" in content_part:
+                    counts[content_part["type"]] += 1
+                else:
+                    counts["text"] += 1
+    elif "type" in message:
+        counts[message["type"]] += 1
+    else:
+        counts["text"] += 1
+
+
+def count_message_types(messages: List[Dict[str, Any]]) -> Dict[str, int]:
+    """
+    Count the number of each message type in a list of messages.
+
+    Args:
+        messages: List of message dictionaries
+
+    Returns:
+        Dictionary mapping message types to counts
+    """
+    message_counts: DefaultDict[str, int] = defaultdict(int)
+
+    for message in messages:
+        _count_message_content(message, message_counts)
+
+    return dict(message_counts)
+
+
+def _get_usage_dict(usage: Usage, model: str, messages: List[Dict[str, Any]]) -> Dict:
    data = dict()
    data["_prompt_tokens"] = usage.prompt_tokens
@@ -47,6 +86,7 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
        data["_prompt_cost_usd"] = None
        data["_completion_cost_usd"] = None
        data["_total_cost_usd"] = None
+    data["_message_counts"] = count_message_types(messages)
    return data

@@ -141,34 +181,6 @@ def handle_llm_exception(
    return _log_llm_exception(e), usage


-def _ensure_messages_fit_in_context_window(
-    messages: List[Dict[str, str]], model: str
-) -> Tuple[List[Dict[str, str]], int]:
-    """
-    Ensure that the messages fit in the context window of the model.
-    """
-    token_count = token_counter(model=model, messages=messages)
-    logger.debug(f"Prompt tokens count: {token_count}")
-
-    if model in litellm.model_cost:
-        # If we are able to identify the model context window, ensure the messages fit in it
-        max_tokens = litellm.model_cost[model].get(
-            "max_input_tokens", litellm.model_cost[model]["max_tokens"]
-        )
-        if token_count > max_tokens:
-            logger.info(
-                f"Prompt tokens count {token_count} exceeds max tokens {max_tokens} for model {model}. Trimming messages."
-            )
-            # TODO: in case it exceeds max tokens, content of the last message is truncated.
-            # to improve this, we implement:
-            # - UI-level warning for the user, use prediction_meta field for warnings as well as errors in future
-            # - sequential aggregation instead of trimming
-            # - potential v2 solution to downsample images instead of cutting them off (using quality=low instead of quality=auto in completion)
-            return trim_messages(messages, model=model), token_count
-    # in other cases, just return the original messages
-    return messages, token_count
-
-
@debug_time_it
def run_instructor_with_messages(
    client: Instructor,
@@ -203,11 +215,6 @@ def run_instructor_with_messages(
        Dict containing the parsed response and usage information
    """
    try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )

        response, completion = client.chat.completions.create_with_completion(
            messages=messages,
@@ -225,14 +232,15 @@ def run_instructor_with_messages(
        usage_model = completion.model

    except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
        # With exceptions we don't have access to completion.model
        usage_model = canonical_model_provider_string or model
+        # Add empty message counts in case of exception

    # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    # Add message counts to usage data
+    dct.update(usage_data)

    return dct

@@ -248,7 +256,6 @@ async def arun_instructor_with_messages(
    temperature: Optional[float] = None,
    seed: Optional[int] = None,
    retries: Optional[AsyncRetrying] = None,
-    ensure_messages_fit_in_context_window: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """
@@ -264,18 +271,12 @@ async def arun_instructor_with_messages(
        temperature: Temperature for sampling
        seed: Integer seed to reduce nondeterminism
        retries: Retry policy to use
-        ensure_messages_fit_in_context_window: Whether to ensure the messages fit in the context window (setting it to True will slow down the function)
        **kwargs: Additional arguments to pass to the completion call

    Returns:
        Dict containing the parsed response and usage information
    """
    try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )

        response, completion = await client.chat.completions.create_with_completion(
            messages=messages,
@@ -293,14 +294,13 @@ async def arun_instructor_with_messages(
        usage_model = completion.model

    except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
        # With exceptions we don't have access to completion.model
        usage_model = canonical_model_provider_string or model

    # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    dct.update(usage_data)

    return dct

@@ -359,7 +359,7 @@ def run_instructor_with_payload(
        split_into_chunks=split_into_chunks,
    )

-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
    return run_instructor_with_messages(
        client,
        messages,
@@ -428,7 +428,7 @@ async def arun_instructor_with_payload(
        split_into_chunks=split_into_chunks,
    )

-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
    return await arun_instructor_with_messages(
        client,
        messages,
@@ -500,7 +500,7 @@ def run_instructor_with_payloads(

    results = []
    for payload in payloads:
-        messages = messages_builder.get_messages(payload)
+        messages = messages_builder.get_messages(payload).messages
        result = run_instructor_with_messages(
            client,
            messages,
@@ -571,23 +571,26 @@ async def arun_instructor_with_payloads(
        input_field_types=input_field_types,
        extra_fields=extra_fields,
        split_into_chunks=split_into_chunks,
+        trim_to_fit_context=ensure_messages_fit_in_context_window,
+        model=canonical_model_provider_string or model,
    )

-    tasks = [
-        arun_instructor_with_messages(
-            client,
-            messages_builder.get_messages(payload),
-            response_model,
-            model,
-            canonical_model_provider_string,
-            max_tokens,
-            temperature,
-            seed,
-            retries,
-            ensure_messages_fit_in_context_window=ensure_messages_fit_in_context_window,
-            **kwargs,
+    tasks = []
+    for payload in payloads:
+        messages = messages_builder.get_messages(payload).messages
+        tasks.append(
+            arun_instructor_with_messages(
+                client,
+                messages,
+                response_model,
+                model,
+                canonical_model_provider_string,
+                max_tokens,
+                temperature,
+                seed,
+                retries,
+                **kwargs,
+            )
        )
-        for payload in payloads
-    ]

    return await asyncio.gather(*tasks)
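
For reference, a minimal sketch of how the new count_message_types helper behaves and where its output surfaces. The sample messages and the import path below are illustrative assumptions, not part of this diff.

# Illustrative usage sketch (assumed import path; the helper is defined in the file modified above).
from adala.utils.llm_utils import count_message_types

# One plain-text message plus one multimodal message with a text part and an image part.
messages = [
    {"role": "system", "content": "You are a labeling assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Classify this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]

# String content counts as "text"; list content is tallied per part "type".
print(count_message_types(messages))
# -> {"text": 2, "image_url": 1}

# _get_usage_dict embeds the same tally under "_message_counts", so each
# prediction's usage record now also reports what kinds of content were sent.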