extract_thinker/llm.py: 175 changes (105 additions, 70 deletions)
@@ -34,14 +34,19 @@ class LLM:
MIN_THINKING_BUDGET = 1200 # Minimum thinking budget
DEFAULT_OUTPUT_TOKENS = 32000

# Model-specific token limits
MODEL_TOKEN_LIMITS = {
"gpt-4o": 12000, # Reasonable middle ground for GPT-4o
}

def __init__(
self,
model: str,
token_limit: int = None,
backend: LLMEngine = LLMEngine.DEFAULT
):
"""Initialize LLM with specified backend.

Args:
model: The model name (e.g. "gpt-4", "claude-3")
token_limit: Optional maximum tokens
@@ -69,7 +74,8 @@ def __init__(
from pydantic_ai import Agent
from pydantic_ai.models import KnownModelName
from typing import cast

import asyncio

self.client = None
self.agent = Agent(
cast(KnownModelName, self.model)
@@ -100,6 +106,15 @@ def _get_pydantic_ai():
"Please install it with `pip install pydantic-ai`."
)

def _get_model_max_tokens(self) -> int:
"""Get the maximum tokens allowed for the current model."""
# Only apply special limit for GPT-4o
if self.model == "gpt-4o":
return self.MODEL_TOKEN_LIMITS["gpt-4o"]

# Default to the general MAX_TOKEN_LIMIT for all other models
return self.MAX_TOKEN_LIMIT
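
[Editor's note] If more capped models are added later, the same lookup can be written table-driven with dict.get instead of a model-by-model check. A minimal sketch, assuming the existing MODEL_TOKEN_LIMITS and MAX_TOKEN_LIMIT class attributes; this is a suggestion, not what the PR ships:

    def _get_model_max_tokens(self) -> int:
        """Return the per-model output cap, falling back to MAX_TOKEN_LIMIT."""
        # Models without an entry fall back to the general limit.
        return self.MODEL_TOKEN_LIMITS.get(self.model, self.MAX_TOKEN_LIMIT)
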

def load_router(self, router: Router) -> None:
"""Load a LiteLLM router for model fallbacks."""
if self.backend != LLMEngine.DEFAULT:
@@ -108,15 +123,15 @@ def load_router(self, router: Router) -> None:

def set_temperature(self, temperature: float) -> None:
"""Set the temperature for LLM requests.

Args:
temperature (float): Temperature value between 0 and 1
"""
self.temperature = temperature

def set_thinking(self, is_thinking: bool) -> None:
"""Set whether the LLM should handle thinking.

Args:
is_thinking (bool): Whether to enable thinking
"""
@@ -125,39 +140,39 @@ def set_thinking(self, is_thinking: bool) -> None:

def set_dynamic(self, is_dynamic: bool) -> None:
"""Set whether the LLM should handle dynamic content.

When dynamic is True, the LLM will attempt to parse and validate JSON responses.
This is useful for handling structured outputs like masking mappings.

Args:
is_dynamic (bool): Whether to enable dynamic content handling
"""
self.is_dynamic = is_dynamic

def set_page_count(self, page_count: int) -> None:
"""Set the page count to calculate token limits for thinking.

Each page is assumed to have DEFAULT_PAGE_TOKENS tokens (text + image).
Thinking budget is calculated as DEFAULT_THINKING_RATIO of the content tokens.

Args:
page_count (int): Number of pages in the document
"""
if page_count <= 0:
raise ValueError("Page count must be a positive integer")

self.page_count = page_count

# Calculate content tokens
content_tokens = min(page_count * self.DEFAULT_PAGE_TOKENS, self.MAX_TOKEN_LIMIT)

# Calculate thinking budget as DEFAULT_THINKING_RATIO of the content tokens
thinking_tokens = int(page_count * self.DEFAULT_PAGE_TOKENS * self.DEFAULT_THINKING_RATIO)

# Apply min/max constraints
thinking_tokens = max(thinking_tokens, self.MIN_THINKING_BUDGET)
thinking_tokens = min(thinking_tokens, self.MAX_THINKING_BUDGET)

# Update token limit and thinking budget
self.thinking_token_limit = content_tokens
self.thinking_budget = thinking_tokens
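
[Editor's note] To make the budget arithmetic in set_page_count concrete, here is a standalone sketch with illustrative constant values; the real DEFAULT_PAGE_TOKENS, DEFAULT_THINKING_RATIO, MAX_TOKEN_LIMIT and MAX_THINKING_BUDGET are defined earlier in the class, outside this hunk, and may differ:

    # Illustrative constants only -- not the values used by extract_thinker.
    DEFAULT_PAGE_TOKENS = 1500
    DEFAULT_THINKING_RATIO = 0.25
    MAX_TOKEN_LIMIT = 64000
    MIN_THINKING_BUDGET = 1200      # matches the constant shown in this diff
    MAX_THINKING_BUDGET = 24000

    page_count = 10
    content_tokens = min(page_count * DEFAULT_PAGE_TOKENS, MAX_TOKEN_LIMIT)           # 15000
    thinking_tokens = int(page_count * DEFAULT_PAGE_TOKENS * DEFAULT_THINKING_RATIO)  # 3750
    thinking_tokens = max(thinking_tokens, MIN_THINKING_BUDGET)                       # 3750
    thinking_tokens = min(thinking_tokens, MAX_THINKING_BUDGET)                       # 3750
    print(content_tokens, thinking_tokens)                                            # 15000 3750
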
@@ -172,21 +187,23 @@ def request(
# Combine messages into a single prompt
combined_prompt = " ".join([m["content"] for m in messages])
try:
result = asyncio.run(
# Reuse the current event loop, or create one if none exists
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

result = loop.run_until_complete(
self.agent.run(
combined_prompt,
result_type=response_model if response_model else str
)
)
return result.data
except Exception as e:
raise ValueError(f"Failed to extract from source: {str(e)}")

# Uncomment the following lines if you need to calculate max_tokens
# contents = map(lambda message: message['content'], messages)
# all_contents = ' '.join(contents)
# max_tokens = num_tokens_from_string(all_contents)

# For sync requests: when dynamic mode is enabled, send response_model=None and apply the model during dynamic parsing after the LLM request
request_model = None if self.is_dynamic else response_model
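
[Editor's note] On the event-loop change above: asyncio.run() creates and closes a fresh loop on every call and raises if a loop is already running, which is why the hunk switches to a get-or-create pattern. A self-contained sketch of that pattern with illustrative names that are not part of the PR; note that run_until_complete still raises if a loop is already running in the current thread, and calling asyncio.get_event_loop() with no running loop is deprecated in recent Python versions:

    import asyncio

    def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
        """Illustrative helper: reuse the thread's event loop or create a new one."""
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    async def _demo() -> str:
        return "ok"

    if __name__ == "__main__":
        loop = _get_or_create_event_loop()
        print(loop.run_until_complete(_demo()))  # prints "ok"
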

@@ -214,7 +231,7 @@ def request(
content = response.choices[0].message.content
if self.is_dynamic:
return extract_thinking_json(content, response_model)

return content

def _request_with_router(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
@@ -234,18 +251,43 @@ def _request_with_router(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
"max_completion_tokens": max_tokens,
}
if self.is_thinking:
if litellm.supports_reasoning(self.model):
# Add thinking parameter for supported models
thinking_param = {
"type": "enabled",
"budget_tokens": self.thinking_budget
}
params["thinking"] = thinking_param
else:
print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
# Add thinking parameter for supported models
thinking_param = {
"type": "enabled",
"budget_tokens": self.thinking_budget
}
try:
return self.router.completion(
model=self.model,
messages=messages,
response_model=response_model,
temperature=self.temperature,
timeout=self.TIMEOUT,
thinking=thinking_param,
)
except Exception as e:
# If thinking parameter causes an error, try without it
if "property 'thinking' is unsupported" in str(e):
print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
return self.router.completion(
model=self.model,
messages=messages,
response_model=response_model,
temperature=self.temperature,
timeout=self.TIMEOUT,
)
else:
raise e
else:
# Normal request without thinking parameter
return self.router.completion(
model=self.model,
messages=messages,
response_model=response_model,
temperature=self.temperature,
timeout=self.TIMEOUT,
)

return self.router.completion(**params)
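
[Editor's note] The retry-without-thinking fallback in this hunk is repeated almost verbatim in raw_completion further down. A hedged sketch of one way to factor it into a shared helper; the helper name and the generic `call` parameter are illustrative and not part of this PR:

    from typing import Any, Callable, Dict

    def _complete_with_optional_thinking(
        call: Callable[..., Any],
        thinking_param: Dict[str, Any],
        **kwargs: Any,
    ) -> Any:
        """Try a completion with the thinking parameter, retry without it if rejected."""
        try:
            return call(thinking=thinking_param, **kwargs)
        except Exception as e:
            if "property 'thinking' is unsupported" in str(e):
                # The provider rejected the thinking parameter; retry the plain request.
                return call(**kwargs)
            raise

Both self.router.completion and litellm.completion could then be passed in as `call`, keeping the warning and retry logic in one place.
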

def _request_direct(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
"""Handle direct request with or without thinking parameter"""
max_tokens = self.DEFAULT_OUTPUT_TOKENS
@@ -260,10 +302,10 @@ def _request_direct(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:
"temperature": self.temperature,
"response_model": response_model,
"max_retries": 1,
"max_completion_tokens": max_tokens,
"max_tokens": self._get_model_max_tokens(), # <- capped max tokens here
"timeout": self.TIMEOUT,
}

if self.is_thinking:
if litellm.supports_reasoning(self.model):
# Try with thinking parameter
@@ -279,50 +321,43 @@ def _request_direct(self, messages: List[Dict[str, str]], response_model: Optional[str]) -> Any:

def raw_completion(self, messages: List[Dict[str, str]]) -> str:
"""Make raw completion request without response model."""
if self.backend == LLMEngine.PYDANTIC_AI:
# Combine messages into a single prompt
combined_prompt = " ".join([m["content"] for m in messages])
try:
result = asyncio.run(
self.agent.run(
combined_prompt,
result_type=str
)
)
return result.data
except Exception as e:
raise ValueError(f"Failed to extract from source: {str(e)}")
max_tokens = self._get_model_max_tokens() # <- capped max tokens here

max_tokens = self.DEFAULT_OUTPUT_TOKENS
if self.token_limit is not None:
max_tokens = self.token_limit
elif self.is_thinking:
max_tokens = self.thinking_token_limit

params = {
"model": self.model,
"messages": messages,
"max_completion_tokens": max_tokens,
}

if self.is_thinking:
if litellm.supports_reasoning(self.model):
if self.router:
raw_response = self.router.completion(**params)
else:
if self.is_thinking:
# Add thinking parameter for supported models
thinking_param = {
"type": "enabled",
"budget_tokens": self.thinking_budget
}
params["thinking"] = thinking_param
try:
raw_response = litellm.completion(
model=self.model,
messages=messages,
max_tokens=max_tokens,
thinking=thinking_param,
)
except Exception as e:
# If thinking parameter causes an error, try without it
if "property 'thinking' is unsupported" in str(e):
print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")
raw_response = litellm.completion(
model=self.model,
messages=messages,
max_tokens=max_tokens,
)
else:
raise e
else:
print(f"Warning: Model {self.model} doesn't support thinking parameter, proceeding without it.")

if self.router:
raw_response = self.router.completion(**params)
else:
raw_response = litellm.completion(**params)

raw_response = litellm.completion(
model=self.model,
messages=messages,
max_tokens=max_tokens,
)
return raw_response.choices[0].message.content

def set_timeout(self, timeout_ms: int) -> None:
"""Set the timeout value for LLM requests in milliseconds."""
self.TIMEOUT = timeout_ms
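
[Editor's note] Finally, a hedged usage sketch showing how the pieces changed in this diff fit together; the import path is assumed from the repository layout, and the message shape mirrors the {"role", "content"} dicts used throughout the file:

    from extract_thinker.llm import LLM  # assumed import path

    llm = LLM("gpt-4o")     # "gpt-4o" picks up the 12000-token cap added above
    llm.set_thinking(True)  # models without reasoning support fall back with a warning
    llm.set_page_count(10)  # derives thinking_token_limit and thinking_budget

    messages = [{"role": "user", "content": "Summarize the attached document."}]
    print(llm.raw_completion(messages))
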