 import litellm
 from litellm import token_counter
 from collections import defaultdict
-from typing import Any, Dict, List, Type, Optional, Tuple
+from typing import Any, Dict, List, Type, Optional, Tuple, DefaultDict
 from pydantic import BaseModel, Field
 from pydantic_core import to_jsonable_python
 from litellm.types.utils import Usage
 from litellm.utils import trim_messages
 from tenacity import Retrying, AsyncRetrying
 from instructor.exceptions import InstructorRetryException, IncompleteOutputException
 from instructor.client import Instructor, AsyncInstructor
-from adala.utils.parse import MessagesBuilder, MessageChunkType
+from adala.utils.parse import MessageChunkType
+from adala.utils.message_builder import MessagesBuilder
 from adala.utils.exceptions import ConstrainedGenerationError
 from adala.utils.types import debug_time_it
 from litellm.exceptions import BadRequestError
 
 logger = logging.getLogger(__name__)
 
 
-def _get_usage_dict(usage: Usage, model: str) -> Dict:
+def _count_message_content(
+    message: Dict[str, Any], counts: DefaultDict[str, int]
+) -> None:
+    """Helper method to count different content types in a message."""
+    if "role" in message and "content" in message:
+        content = message["content"]
+        if isinstance(content, str):
+            counts["text"] += 1
+        elif isinstance(content, list):
+            for content_part in content:
+                if isinstance(content_part, dict) and "type" in content_part:
+                    counts[content_part["type"]] += 1
+                else:
+                    counts["text"] += 1
+    elif "type" in message:
+        counts[message["type"]] += 1
+    else:
+        counts["text"] += 1
+
+
+def count_message_types(messages: List[Dict[str, Any]]) -> Dict[str, int]:
+    """
+    Count the number of each message type in a list of messages.
+
+    Args:
+        messages: List of message dictionaries
+
+    Returns:
+        Dictionary mapping message types to counts
+    """
+    message_counts: DefaultDict[str, int] = defaultdict(int)
+
+    for message in messages:
+        _count_message_content(message, message_counts)
+
+    return dict(message_counts)
+
+
+def _get_usage_dict(usage: Usage, model: str, messages: List[Dict[str, Any]]) -> Dict:
     data = dict()
     data["_prompt_tokens"] = usage.prompt_tokens
 
@@ -47,6 +86,7 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
         data["_prompt_cost_usd"] = None
         data["_completion_cost_usd"] = None
         data["_total_cost_usd"] = None
+    data["_message_counts"] = count_message_types(messages)
     return data
 
 
@@ -141,34 +181,6 @@ def handle_llm_exception(
     return _log_llm_exception(e), usage
 
 
-def _ensure_messages_fit_in_context_window(
-    messages: List[Dict[str, str]], model: str
-) -> Tuple[List[Dict[str, str]], int]:
-    """
-    Ensure that the messages fit in the context window of the model.
-    """
-    token_count = token_counter(model=model, messages=messages)
-    logger.debug(f"Prompt tokens count: {token_count}")
-
-    if model in litellm.model_cost:
-        # If we are able to identify the model context window, ensure the messages fit in it
-        max_tokens = litellm.model_cost[model].get(
-            "max_input_tokens", litellm.model_cost[model]["max_tokens"]
-        )
-        if token_count > max_tokens:
-            logger.info(
-                f"Prompt tokens count {token_count} exceeds max tokens {max_tokens} for model {model}. Trimming messages."
-            )
-            # TODO: in case it exceeds max tokens, content of the last message is truncated.
-            # to improve this, we implement:
-            # - UI-level warning for the user, use prediction_meta field for warnings as well as errors in future
-            # - sequential aggregation instead of trimming
-            # - potential v2 solution to downsample images instead of cutting them off (using quality=low instead of quality=auto in completion)
-            return trim_messages(messages, model=model), token_count
-    # in other cases, just return the original messages
-    return messages, token_count
-
-
 @debug_time_it
 def run_instructor_with_messages(
     client: Instructor,
@@ -203,11 +215,6 @@ def run_instructor_with_messages(
         Dict containing the parsed response and usage information
     """
     try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )
 
         response, completion = client.chat.completions.create_with_completion(
             messages=messages,
@@ -225,14 +232,15 @@ def run_instructor_with_messages(
         usage_model = completion.model
 
     except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
         # With exceptions we don't have access to completion.model
         usage_model = canonical_model_provider_string or model
+        # Add empty message counts in case of exception
 
     # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    # Add message counts to usage data
+    dct.update(usage_data)
 
     return dct
 
@@ -248,7 +256,6 @@ async def arun_instructor_with_messages(
     temperature: Optional[float] = None,
     seed: Optional[int] = None,
     retries: Optional[AsyncRetrying] = None,
-    ensure_messages_fit_in_context_window: bool = False,
     **kwargs,
 ) -> Dict[str, Any]:
     """
@@ -264,18 +271,12 @@ async def arun_instructor_with_messages(
         temperature: Temperature for sampling
         seed: Integer seed to reduce nondeterminism
         retries: Retry policy to use
-        ensure_messages_fit_in_context_window: Whether to ensure the messages fit in the context window (setting it to True will slow down the function)
        **kwargs: Additional arguments to pass to the completion call
 
     Returns:
         Dict containing the parsed response and usage information
     """
     try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )
 
         response, completion = await client.chat.completions.create_with_completion(
             messages=messages,
@@ -293,14 +294,13 @@ async def arun_instructor_with_messages(
         usage_model = completion.model
 
     except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
         # With exceptions we don't have access to completion.model
         usage_model = canonical_model_provider_string or model
 
     # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    dct.update(usage_data)
 
     return dct
 
@@ -359,7 +359,7 @@ def run_instructor_with_payload(
         split_into_chunks=split_into_chunks,
     )
 
-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
     return run_instructor_with_messages(
         client,
         messages,
@@ -428,7 +428,7 @@ async def arun_instructor_with_payload(
         split_into_chunks=split_into_chunks,
     )
 
-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
     return await arun_instructor_with_messages(
         client,
         messages,
@@ -500,7 +500,7 @@ def run_instructor_with_payloads(
 
     results = []
     for payload in payloads:
-        messages = messages_builder.get_messages(payload)
+        messages = messages_builder.get_messages(payload).messages
         result = run_instructor_with_messages(
             client,
             messages,
@@ -571,23 +571,26 @@ async def arun_instructor_with_payloads(
         input_field_types=input_field_types,
         extra_fields=extra_fields,
         split_into_chunks=split_into_chunks,
+        trim_to_fit_context=ensure_messages_fit_in_context_window,
+        model=canonical_model_provider_string or model,
     )
 
-    tasks = [
-        arun_instructor_with_messages(
-            client,
-            messages_builder.get_messages(payload),
-            response_model,
-            model,
-            canonical_model_provider_string,
-            max_tokens,
-            temperature,
-            seed,
-            retries,
-            ensure_messages_fit_in_context_window=ensure_messages_fit_in_context_window,
-            **kwargs,
+    tasks = []
+    for payload in payloads:
+        messages = messages_builder.get_messages(payload).messages
+        tasks.append(
+            arun_instructor_with_messages(
+                client,
+                messages,
+                response_model,
+                model,
+                canonical_model_provider_string,
+                max_tokens,
+                temperature,
+                seed,
+                retries,
+                **kwargs,
+            )
         )
-        for payload in payloads
-    ]
 
     return await asyncio.gather(*tasks)
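
For reference, a minimal sketch of how the new count_message_types helper tallies mixed-content chat messages, following the logic added above. The import path is an assumption (the diff does not show the file name), so adjust it to wherever these helpers are defined:

# Assumed import path -- the commit does not show the module name.
from adala.utils.llm_utils import count_message_types

messages = [
    # A plain string "content" counts as one "text" entry.
    {"role": "system", "content": "You are a helpful assistant."},
    # List-style content is counted per part, keyed by each part's "type".
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]

print(count_message_types(messages))  # {'text': 2, 'image_url': 1}

This dictionary is what _get_usage_dict now attaches to each response under the _message_counts key, both on success and when an exception is handled.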