Commit a20f2e6

Authored by niklubnik and nik

feat: DIA-2067: Add message type counters in MessageBuilder (#369)

Co-authored-by: nik <[email protected]>
1 parent 9bb28c2, commit a20f2e6

18 files changed: +6502 / -449 lines

Diff for: .cursorignore (+1)

@@ -0,0 +1 @@
+# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)

Diff for: adala/utils/llm_utils.py (+72 / -69)
@@ -4,23 +4,62 @@
 import litellm
 from litellm import token_counter
 from collections import defaultdict
-from typing import Any, Dict, List, Type, Optional, Tuple
+from typing import Any, Dict, List, Type, Optional, Tuple, DefaultDict
 from pydantic import BaseModel, Field
 from pydantic_core import to_jsonable_python
 from litellm.types.utils import Usage
 from litellm.utils import trim_messages
 from tenacity import Retrying, AsyncRetrying
 from instructor.exceptions import InstructorRetryException, IncompleteOutputException
 from instructor.client import Instructor, AsyncInstructor
-from adala.utils.parse import MessagesBuilder, MessageChunkType
+from adala.utils.parse import MessageChunkType
+from adala.utils.message_builder import MessagesBuilder
 from adala.utils.exceptions import ConstrainedGenerationError
 from adala.utils.types import debug_time_it
 from litellm.exceptions import BadRequestError

 logger = logging.getLogger(__name__)


-def _get_usage_dict(usage: Usage, model: str) -> Dict:
+def _count_message_content(
+    message: Dict[str, Any], counts: DefaultDict[str, int]
+) -> None:
+    """Helper method to count different content types in a message."""
+    if "role" in message and "content" in message:
+        content = message["content"]
+        if isinstance(content, str):
+            counts["text"] += 1
+        elif isinstance(content, list):
+            for content_part in content:
+                if isinstance(content_part, dict) and "type" in content_part:
+                    counts[content_part["type"]] += 1
+                else:
+                    counts["text"] += 1
+    elif "type" in message:
+        counts[message["type"]] += 1
+    else:
+        counts["text"] += 1
+
+
+def count_message_types(messages: List[Dict[str, Any]]) -> Dict[str, int]:
+    """
+    Count the number of each message type in a list of messages.
+
+    Args:
+        messages: List of message dictionaries
+
+    Returns:
+        Dictionary mapping message types to counts
+    """
+    message_counts: DefaultDict[str, int] = defaultdict(int)
+
+    for message in messages:
+        _count_message_content(message, message_counts)
+
+    return dict(message_counts)
+
+
+def _get_usage_dict(usage: Usage, model: str, messages: List[Dict[str, Any]]) -> Dict:
     data = dict()
     data["_prompt_tokens"] = usage.prompt_tokens
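The two new counting helpers above are pure functions over chat-message dicts, so their behavior is easy to check in isolation. A minimal sketch of calling count_message_types on a mixed text-and-image conversation (the message payloads are illustrative, not taken from the repo's tests):

from adala.utils.llm_utils import count_message_types

messages = [
    {"role": "system", "content": "You are a labeling assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]

# Plain string content counts as "text"; list content is tallied per part "type".
print(count_message_types(messages))  # {'text': 2, 'image_url': 1}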
@@ -47,6 +86,7 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
     data["_prompt_cost_usd"] = None
     data["_completion_cost_usd"] = None
     data["_total_cost_usd"] = None
+    data["_message_counts"] = count_message_types(messages)
     return data

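With this hunk, every usage payload built by _get_usage_dict carries a _message_counts entry alongside the existing token and cost fields. Roughly, the returned dict now has this shape (values are placeholders; fields not shown in this diff are omitted):

{
    "_prompt_tokens": 123,
    "_prompt_cost_usd": None,
    "_completion_cost_usd": None,
    "_total_cost_usd": None,
    "_message_counts": {"text": 2, "image_url": 1},
}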
@@ -141,34 +181,6 @@ def handle_llm_exception(
     return _log_llm_exception(e), usage


-def _ensure_messages_fit_in_context_window(
-    messages: List[Dict[str, str]], model: str
-) -> Tuple[List[Dict[str, str]], int]:
-    """
-    Ensure that the messages fit in the context window of the model.
-    """
-    token_count = token_counter(model=model, messages=messages)
-    logger.debug(f"Prompt tokens count: {token_count}")
-
-    if model in litellm.model_cost:
-        # If we are able to identify the model context window, ensure the messages fit in it
-        max_tokens = litellm.model_cost[model].get(
-            "max_input_tokens", litellm.model_cost[model]["max_tokens"]
-        )
-        if token_count > max_tokens:
-            logger.info(
-                f"Prompt tokens count {token_count} exceeds max tokens {max_tokens} for model {model}. Trimming messages."
-            )
-            # TODO: in case it exceeds max tokens, content of the last message is truncated.
-            # to improve this, we implement:
-            # - UI-level warning for the user, use prediction_meta field for warnings as well as errors in future
-            # - sequential aggregation instead of trimming
-            # - potential v2 solution to downsample images instead of cutting them off (using quality=low instead of quality=auto in completion)
-            return trim_messages(messages, model=model), token_count
-    # in other cases, just return the original messages
-    return messages, token_count
-
-
 @debug_time_it
 def run_instructor_with_messages(
     client: Instructor,
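The context-window check removed here is not lost: in the final hunk of this file, arun_instructor_with_payloads passes trim_to_fit_context and model into MessagesBuilder, so trimming now happens while the messages are being built. The new builder lives in adala/utils/message_builder.py and is not shown in this diff; as a rough sketch, the builder-side equivalent of the deleted helper could reuse the same litellm calls (the function name _trim_if_needed is hypothetical, not the actual MessagesBuilder implementation):

import litellm
from litellm import token_counter
from litellm.utils import trim_messages


def _trim_if_needed(messages, model):
    # Hypothetical sketch of the trimming the builder would now own.
    token_count = token_counter(model=model, messages=messages)
    if model in litellm.model_cost:
        limits = litellm.model_cost[model]
        max_tokens = limits.get("max_input_tokens", limits["max_tokens"])
        if token_count > max_tokens:
            # Same strategy as the deleted helper: let litellm truncate the last
            # message so the prompt fits the model's context window.
            return trim_messages(messages, model=model)
    return messages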
@@ -203,11 +215,6 @@ def run_instructor_with_messages(
         Dict containing the parsed response and usage information
     """
     try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )

         response, completion = client.chat.completions.create_with_completion(
             messages=messages,
@@ -225,14 +232,15 @@ def run_instructor_with_messages(
         usage_model = completion.model

     except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
         # With exceptions we don't have access to completion.model
         usage_model = canonical_model_provider_string or model
+        # Add empty message counts in case of exception

     # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    # Add message counts to usage data
+    dct.update(usage_data)

     return dct

@@ -248,7 +256,6 @@ async def arun_instructor_with_messages(
     temperature: Optional[float] = None,
     seed: Optional[int] = None,
     retries: Optional[AsyncRetrying] = None,
-    ensure_messages_fit_in_context_window: bool = False,
     **kwargs,
 ) -> Dict[str, Any]:
     """
@@ -264,18 +271,12 @@ async def arun_instructor_with_messages(
         temperature: Temperature for sampling
         seed: Integer seed to reduce nondeterminism
         retries: Retry policy to use
-        ensure_messages_fit_in_context_window: Whether to ensure the messages fit in the context window (setting it to True will slow down the function)
         **kwargs: Additional arguments to pass to the completion call

     Returns:
         Dict containing the parsed response and usage information
     """
     try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )

         response, completion = await client.chat.completions.create_with_completion(
             messages=messages,
@@ -293,14 +294,13 @@ async def arun_instructor_with_messages(
         usage_model = completion.model

     except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
         # With exceptions we don't have access to completion.model
         usage_model = canonical_model_provider_string or model

     # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    dct.update(usage_data)

     return dct

@@ -359,7 +359,7 @@ def run_instructor_with_payload(
         split_into_chunks=split_into_chunks,
     )

-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
     return run_instructor_with_messages(
         client,
         messages,
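This hunk (and the two that follow) changes the builder contract: get_messages no longer returns a bare list, but an object whose .messages attribute holds the chat messages. The wrapper type is defined in adala/utils/message_builder.py and is not visible in this diff; presumably it also carries the per-type counters and trimming metadata this PR adds. A minimal sketch of the new call pattern:

result = messages_builder.get_messages(payload)
messages = result.messages  # the chat messages passed on to the instructor client
# Any additional fields on `result` (e.g. counters) are assumptions here, since
# only `.messages` is exercised in llm_utils.py.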
@@ -428,7 +428,7 @@ async def arun_instructor_with_payload(
         split_into_chunks=split_into_chunks,
     )

-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
     return await arun_instructor_with_messages(
         client,
         messages,
@@ -500,7 +500,7 @@ def run_instructor_with_payloads(

     results = []
     for payload in payloads:
-        messages = messages_builder.get_messages(payload)
+        messages = messages_builder.get_messages(payload).messages
         result = run_instructor_with_messages(
             client,
             messages,
@@ -571,23 +571,26 @@ async def arun_instructor_with_payloads(
         input_field_types=input_field_types,
         extra_fields=extra_fields,
         split_into_chunks=split_into_chunks,
+        trim_to_fit_context=ensure_messages_fit_in_context_window,
+        model=canonical_model_provider_string or model,
     )

-    tasks = [
-        arun_instructor_with_messages(
-            client,
-            messages_builder.get_messages(payload),
-            response_model,
-            model,
-            canonical_model_provider_string,
-            max_tokens,
-            temperature,
-            seed,
-            retries,
-            ensure_messages_fit_in_context_window=ensure_messages_fit_in_context_window,
-            **kwargs,
+    tasks = []
+    for payload in payloads:
+        messages = messages_builder.get_messages(payload).messages
+        tasks.append(
+            arun_instructor_with_messages(
+                client,
+                messages,
+                response_model,
+                model,
+                canonical_model_provider_string,
+                max_tokens,
+                temperature,
+                seed,
+                retries,
+                **kwargs,
+            )
         )
-        for payload in payloads
-    ]

     return await asyncio.gather(*tasks)
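End to end, the caller-visible effect of this commit is an extra key in every result's usage data, plus the builder-level trimming wired in above. A hedged sketch of consuming the new counters (the positional argument order mirrors what is visible for the async variant in this diff; the result keys come from _get_usage_dict):

result = run_instructor_with_messages(
    client, messages, response_model, model,
    canonical_model_provider_string, max_tokens, temperature, seed, retries,
)
message_counts = result.get("_message_counts", {})
print(message_counts)  # e.g. {'text': 2, 'image_url': 1}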
