1 change: 1 addition & 0 deletions .cursorignore
@@ -0,0 +1 @@
+# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
141 changes: 72 additions & 69 deletions adala/utils/llm_utils.py
@@ -4,23 +4,62 @@
import litellm
from litellm import token_counter
from collections import defaultdict
-from typing import Any, Dict, List, Type, Optional, Tuple
+from typing import Any, Dict, List, Type, Optional, Tuple, DefaultDict
from pydantic import BaseModel, Field
from pydantic_core import to_jsonable_python
from litellm.types.utils import Usage
from litellm.utils import trim_messages
from tenacity import Retrying, AsyncRetrying
from instructor.exceptions import InstructorRetryException, IncompleteOutputException
from instructor.client import Instructor, AsyncInstructor
-from adala.utils.parse import MessagesBuilder, MessageChunkType
+from adala.utils.parse import MessageChunkType
+from adala.utils.message_builder import MessagesBuilder
from adala.utils.exceptions import ConstrainedGenerationError
from adala.utils.types import debug_time_it
from litellm.exceptions import BadRequestError

logger = logging.getLogger(__name__)


-def _get_usage_dict(usage: Usage, model: str) -> Dict:
+def _count_message_content(
+    message: Dict[str, Any], counts: DefaultDict[str, int]
+) -> None:
+    """Helper method to count different content types in a message."""
+    if "role" in message and "content" in message:
+        content = message["content"]
+        if isinstance(content, str):
+            counts["text"] += 1
+        elif isinstance(content, list):
+            for content_part in content:
+                if isinstance(content_part, dict) and "type" in content_part:
+                    counts[content_part["type"]] += 1
+                else:
+                    counts["text"] += 1
+    elif "type" in message:
+        counts[message["type"]] += 1
+    else:
+        counts["text"] += 1
+
+
+def count_message_types(messages: List[Dict[str, Any]]) -> Dict[str, int]:
+    """
+    Count the number of each message type in a list of messages.
+
+    Args:
+        messages: List of message dictionaries
+
+    Returns:
+        Dictionary mapping message types to counts
+    """
+    message_counts: DefaultDict[str, int] = defaultdict(int)
+
+    for message in messages:
+        _count_message_content(message, message_counts)
+
+    return dict(message_counts)
+
+
+def _get_usage_dict(usage: Usage, model: str, messages: List[Dict[str, Any]]) -> Dict:
    data = dict()
    data["_prompt_tokens"] = usage.prompt_tokens

@@ -47,6 +86,7 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
        data["_prompt_cost_usd"] = None
        data["_completion_cost_usd"] = None
        data["_total_cost_usd"] = None
+    data["_message_counts"] = count_message_types(messages)
    return data
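As a quick illustration of the new counter (sample input, not from this PR): it tallies one entry per content part on litellm-style message dicts, falling back to "text" for plain strings.

    messages = [
        {"role": "user", "content": "plain text"},  # string content counts as "text"
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "describe this image"},
                {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
            ],
        },
    ]
    count_message_types(messages)  # -> {"text": 2, "image_url": 1}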


@@ -141,34 +181,6 @@ def handle_llm_exception(
    return _log_llm_exception(e), usage


-def _ensure_messages_fit_in_context_window(
-    messages: List[Dict[str, str]], model: str
-) -> Tuple[List[Dict[str, str]], int]:
-    """
-    Ensure that the messages fit in the context window of the model.
-    """
-    token_count = token_counter(model=model, messages=messages)
-    logger.debug(f"Prompt tokens count: {token_count}")
-
-    if model in litellm.model_cost:
-        # If we are able to identify the model context window, ensure the messages fit in it
-        max_tokens = litellm.model_cost[model].get(
-            "max_input_tokens", litellm.model_cost[model]["max_tokens"]
-        )
-        if token_count > max_tokens:
-            logger.info(
-                f"Prompt tokens count {token_count} exceeds max tokens {max_tokens} for model {model}. Trimming messages."
-            )
-            # TODO: in case it exceeds max tokens, content of the last message is truncated.
-            # to improve this, we implement:
-            # - UI-level warning for the user, use prediction_meta field for warnings as well as errors in future
-            # - sequential aggregation instead of trimming
-            # - potential v2 solution to downsample images instead of cutting them off (using quality=low instead of quality=auto in completion)
-        return trim_messages(messages, model=model), token_count
-    # in other cases, just return the original messages
-    return messages, token_count
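The trimming responsibility appears to move into MessagesBuilder (note the trim_to_fit_context and model arguments added further down in this diff). For reference, the removed helper was a thin wrapper over two litellm utilities, roughly:

    from litellm import token_counter
    from litellm.utils import trim_messages

    # count prompt tokens for the model, then trim content until it fits the context window
    n_tokens = token_counter(model="gpt-4o", messages=messages)  # model name illustrative
    trimmed = trim_messages(messages, model="gpt-4o")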


@debug_time_it
def run_instructor_with_messages(
    client: Instructor,
@@ -203,11 +215,6 @@ def run_instructor_with_messages(
        Dict containing the parsed response and usage information
    """
    try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )

        response, completion = client.chat.completions.create_with_completion(
            messages=messages,
@@ -225,14 +232,15 @@ def run_instructor_with_messages(
        usage_model = completion.model

    except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
        # With exceptions we don't have access to completion.model
        usage_model = canonical_model_provider_string or model
+        # Add empty message counts in case of exception

    # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    # Add message counts to usage data
+    dct.update(usage_data)

    return dct
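After this change the returned dct carries the parsed response fields plus usage entries of roughly this shape (a sketch; the completion-token fields sit outside this excerpt):

    {
        "_prompt_tokens": 1234,
        "_prompt_cost_usd": None,  # None whenever cost lookup fails
        "_completion_cost_usd": None,
        "_total_cost_usd": None,
        "_message_counts": {"text": 2, "image_url": 1},  # new in this PR
    }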

@@ -248,7 +256,6 @@ async def arun_instructor_with_messages(
    temperature: Optional[float] = None,
    seed: Optional[int] = None,
    retries: Optional[AsyncRetrying] = None,
-    ensure_messages_fit_in_context_window: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """
@@ -264,18 +271,12 @@
        temperature: Temperature for sampling
        seed: Integer seed to reduce nondeterminism
        retries: Retry policy to use
-        ensure_messages_fit_in_context_window: Whether to ensure the messages fit in the context window (setting it to True will slow down the function)
        **kwargs: Additional arguments to pass to the completion call

    Returns:
        Dict containing the parsed response and usage information
    """
    try:
-        prompt_token_count = None
-        if ensure_messages_fit_in_context_window:
-            messages, prompt_token_count = _ensure_messages_fit_in_context_window(
-                messages, canonical_model_provider_string or model
-            )

        response, completion = await client.chat.completions.create_with_completion(
            messages=messages,
@@ -293,14 +294,13 @@ async def arun_instructor_with_messages(
        usage_model = completion.model

    except Exception as e:
-        dct, usage = handle_llm_exception(
-            e, messages, model, retries, prompt_token_count=prompt_token_count
-        )
+        dct, usage = handle_llm_exception(e, messages, model, retries)
        # With exceptions we don't have access to completion.model
        usage_model = canonical_model_provider_string or model

    # Add usage data to the response (e.g. token counts, cost)
-    dct.update(_get_usage_dict(usage, model=usage_model))
+    usage_data = _get_usage_dict(usage, model=usage_model, messages=messages)
+    dct.update(usage_data)

    return dct

@@ -359,7 +359,7 @@ def run_instructor_with_payload(
        split_into_chunks=split_into_chunks,
    )

-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
    return run_instructor_with_messages(
        client,
        messages,
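The new .messages accessor implies MessagesBuilder.get_messages now returns a wrapper object rather than a bare list. The real class lives in adala.utils.message_builder; a hypothetical minimal shape, for illustration only:

    from dataclasses import dataclass
    from typing import Any, Dict, List

    @dataclass
    class BuiltMessages:  # hypothetical name, for illustration only
        messages: List[Dict[str, Any]]  # chat messages rendered from the payload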
@@ -428,7 +428,7 @@ async def arun_instructor_with_payload(
        split_into_chunks=split_into_chunks,
    )

-    messages = messages_builder.get_messages(payload)
+    messages = messages_builder.get_messages(payload).messages
    return await arun_instructor_with_messages(
        client,
        messages,
@@ -500,7 +500,7 @@ def run_instructor_with_payloads(

    results = []
    for payload in payloads:
-        messages = messages_builder.get_messages(payload)
+        messages = messages_builder.get_messages(payload).messages
        result = run_instructor_with_messages(
            client,
            messages,
@@ -571,23 +571,26 @@ async def arun_instructor_with_payloads(
        input_field_types=input_field_types,
        extra_fields=extra_fields,
        split_into_chunks=split_into_chunks,
+        trim_to_fit_context=ensure_messages_fit_in_context_window,
+        model=canonical_model_provider_string or model,
    )

-    tasks = [
-        arun_instructor_with_messages(
-            client,
-            messages_builder.get_messages(payload),
-            response_model,
-            model,
-            canonical_model_provider_string,
-            max_tokens,
-            temperature,
-            seed,
-            retries,
-            ensure_messages_fit_in_context_window=ensure_messages_fit_in_context_window,
-            **kwargs,
-        )
-        for payload in payloads
-    ]
+    tasks = []
+    for payload in payloads:
+        messages = messages_builder.get_messages(payload).messages
+        tasks.append(
+            arun_instructor_with_messages(
+                client,
+                messages,
+                response_model,
+                model,
+                canonical_model_provider_string,
+                max_tokens,
+                temperature,
+                seed,
+                retries,
+                **kwargs,
+            )
+        )

    return await asyncio.gather(*tasks)
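The comprehension becomes an explicit loop so each payload's messages can be unwrapped before the coroutine is created; asyncio.gather then awaits the calls concurrently, as in this self-contained sketch of the pattern:

    import asyncio

    async def call(i: int) -> int:
        await asyncio.sleep(0)  # stand-in for an LLM request
        return i * 2

    async def main() -> None:
        tasks = [call(i) for i in range(3)]
        print(await asyncio.gather(*tasks))  # [0, 2, 4]

    asyncio.run(main())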