fix: DIA-1986: New model providers don't bubble up task-level errors #346

Merged
merged 8 commits
Mar 6, 2025
198 changes: 52 additions & 146 deletions adala/runtimes/_litellm.py
@@ -46,6 +46,11 @@
run_instructor_with_messages,
arun_instructor_with_messages,
)
from adala.utils.model_info_utils import (
match_model_provider_string,
NoModelsFoundError,
_estimate_cost,
)

from pydantic import ConfigDict, field_validator, BaseModel
from pydantic_core import to_jsonable_python
@@ -85,27 +90,6 @@
async_retries = AsyncRetrying(**RETRY_POLICY)


def normalize_litellm_model_and_provider(model_name: str, provider: str):
"""
When using litellm.get_model_info() some models are accessed with their provider prefix
while others are not.

This helper function contains logic which normalizes this for supported providers
"""
if "/" in model_name:
model_name = model_name.split("/", 1)[1]
provider = provider.lower()
# TODO: move this logic to LSE, this is the last place Adala needs to be updated when adding a provider connection
if provider == "vertexai":
provider = "vertex_ai"
if provider == "azureopenai":
provider = "azure"
if provider == "azureaifoundry":
provider = "azure_ai"

return model_name, provider
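
For reference, the removed helper was a straight string normalization; a quick sketch of what it did (the model and provider strings below are illustrative inputs only):

    # Illustrative only: applying the helper above to a prefixed model name
    model, provider = normalize_litellm_model_and_provider(
        "vertex_ai/gemini-1.5-pro", "VertexAI"
    )
    # model == "gemini-1.5-pro", provider == "vertex_ai"
    # (the prefix is stripped and the provider alias is mapped to litellm's spelling)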


class InstructorClientMixin(BaseModel):

# Note: most models work better with json mode; this is set only for backwards compatibility
@@ -124,18 +108,19 @@ def _openai_client(self):
return OpenAI

def _check_client(self):
run_instructor_with_messages(
client=self.client,
# don't use response model and error handling from run_instructor_with_messages here
response = self.client.chat.completions.create(
messages=[{"role": "user", "content": "Hey, how's it going?"}],
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
response_model=None,
retries=retries,
max_retries=retries,
# extra inference params passed to this runtime
**self.model_extra,
)
return response
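
The point of dropping run_instructor_with_messages here is that _check_client should surface raw provider errors rather than having them handled inside the instructor wrapper; a rough sketch of the assumed usage (the runtime construction and keyword names are illustrative, not taken from the diff):

    # Illustrative only: letting a bad key / unknown deployment error propagate
    # runtime = LiteLLMChatRuntime(model="gpt-4o", api_key="...")
    # try:
    #     runtime._check_client()
    # except Exception:
    #     # the raw litellm/provider exception reaches the caller instead of being
    #     # absorbed by run_instructor_with_messages' error handling
    #     raise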

# Yes, this is recomputed every time - there's no clean way to cache it unless we drop pickle serialization, in which case adding it to ConfigDict(ignore) would work.
# There's no appreciable startup cost in the instructor client init function anyway.
@@ -176,6 +161,33 @@ def init_runtime(self) -> "Runtime":

return self

def get_canonical_model_provider_string(self, model: str) -> str:
"""provider_name/model_name"""
# this is really a litellm function, not an instructor function. Putting it here to avoid duplicating it between sync/async runtimes.
try:
return match_model_provider_string(model)
except NoModelsFoundError:
logger.info(
f"Model {model} not found in litellm model map for provider {self.provider}. This is likely a single-model deployment."
)
except Exception as e:
logger.exception(
f"(1/2) Failed to get canonical model provider string for {model}"
)
try:
resp = self._check_client()
return match_model_provider_string(resp.model)
except NoModelsFoundError:
logger.warning(
f"Model {model} not found in litellm model map for provider {self.provider}. This is likely a custom model."
)
return model
except Exception as e:
logger.exception(
f"(2/2) Failed to get canonical model provider string for {model}"
)
return model
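
A sketch of how this resolution chain is expected to behave; the model strings and the returned canonical form are assumptions for illustration, not values checked against the litellm model map:

    # Illustrative only:
    # 1. A model litellm already knows resolves directly from the map:
    #    runtime.get_canonical_model_provider_string("gpt-4o")   # e.g. "openai/gpt-4o"
    # 2. A custom deployment name misses the map, so _check_client() asks the
    #    endpoint what it is actually serving and the lookup is retried on resp.model.
    # 3. If both lookups fail, the original string is returned unchanged so callers
    #    degrade gracefully (no cost/metadata rather than a hard failure).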


class InstructorAsyncClientMixin(InstructorClientMixin):

@@ -190,18 +202,19 @@ def _openai_client(self):
def _check_client(self):
"""Make this synchronous"""
client = InstructorClientMixin(**self.model_dump()).client
run_instructor_with_messages(
client=client,
# don't use response model and error handling from run_instructor_with_messages here
response = client.chat.completions.create(
messages=[{"role": "user", "content": "Hey, how's it going?"}],
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
seed=self.seed,
response_model=None,
retries=retries,
max_retries=retries,
# extra inference params passed to this runtime
**self.model_extra,
)
return response


class LiteLLMChatRuntime(InstructorClientMixin, Runtime):
@@ -232,6 +245,11 @@ class LiteLLMChatRuntime(InstructorClientMixin, Runtime):

model_config = ConfigDict(extra="allow")

@property
def canonical_model_provider_string(self):
"""provider_name/model_name"""
return self.get_canonical_model_provider_string(self.model)

def get_llm_response(self, messages: List[Dict[str, str]]) -> str:
# TODO: sunset this method in favor of record_to_record
if self.verbose:
@@ -342,6 +360,11 @@ def check_concurrency(cls, value) -> int:
)
return value

@property
def canonical_model_provider_string(self):
"""provider_name/model_name"""
return self.get_canonical_model_provider_string(self.model)

async def batch_to_batch(
self,
batch: InternalDataFrame,
@@ -430,83 +453,6 @@ async def record_to_record(
# Extract the single row from the output DataFrame and convert it to a dictionary
return output_df.iloc[0].to_dict()

@staticmethod
def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
user_tokens = litellm.token_counter(model=model, text=string)
# FIXME surprisingly difficult to get function call tokens, and doesn't add a ton of value, so hard-coding until something like litellm supports doing this for us.
# currently seems like we'd need to scrape the instructor logs to get the function call info, then use (at best) an openai-specific 3rd party lib to get a token estimate from that.
system_tokens = 56 + (6 * len(output_fields))
return user_tokens + system_tokens
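
A worked example of the hard-coded heuristic above (the token counts are invented for illustration):

    # Illustrative only: a user prompt that litellm counts as 200 tokens, with
    # 3 output fields, is estimated at 200 + 56 + (6 * 3) = 274 prompt tokens.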

@staticmethod
def _get_completion_tokens(
candidate_model_names: List[str],
output_fields: Optional[List[str]],
provider: str,
) -> int:
max_tokens = None
for model in candidate_model_names:
try:
max_tokens = litellm.get_model_info(model=model).get("max_tokens", None)
break
except Exception as e:
if "model isn't mapped" in str(e):
continue
else:
raise e
if not max_tokens:
raise ValueError
# extremely rough heuristic, from testing on some anecdotal examples
n_outputs = len(output_fields) if output_fields else 1
return min(max_tokens, 4 * n_outputs)

@classmethod
def _estimate_cost(
cls,
user_prompt: str,
model: str,
output_fields: Optional[List[str]],
provider: str,
):
prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
# amazingly, litellm.cost_per_token refers to a hardcoded dictionary litellm.model_cost which is case-sensitive with inconsistent casing.....
# Example: 'azure_ai/deepseek-r1' vs 'azure_ai/Llama-3.3-70B-Instruct'
# so we have no way of determining the correct casing or reliably fixing it.
# we can at least try all-lowercase.
candidate_model_names = [model, model.lower()]
# ...and Azure AI Foundry openai models are not listed there, but under Azure OpenAI
if model.startswith("azure_ai/"):
candidate_model_names.append(model.replace("azure_ai/", "azure/"))
candidate_model_names.append(model.replace("azure_ai/", "azure/").lower())
candidate_model_names = list(set(candidate_model_names))

completion_tokens = cls._get_completion_tokens(
candidate_model_names, output_fields, provider
)

prompt_cost, completion_cost = None, None
for candidate_model_name in candidate_model_names:
try:
prompt_cost, completion_cost = litellm.cost_per_token(
model=candidate_model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
except Exception as e:
# it also doesn't have a type to catch:
# Exception("This model isn't mapped yet. model=azure_ai/deepseek-R1, custom_llm_provider=azure_ai. Add it here - https://github.com/ BerriAI/litellm/blob/main/model_prices_and_context_window.json.")
if "model isn't mapped" in str(e):
pass
if prompt_cost is not None and completion_cost is not None:
break

if prompt_cost is None or completion_cost is None:
raise ValueError(f"Model {model} for provider {provider} not found.")

total_cost = prompt_cost + completion_cost

return prompt_cost, completion_cost, total_cost
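
To make the candidate-name fallback above concrete, here is the list it builds for the model quoted in the comment (illustrative; which candidate litellm actually maps depends on its pricing table):

    # Illustrative only: for model = "azure_ai/deepseek-R1" the candidates are
    #   {"azure_ai/deepseek-R1", "azure_ai/deepseek-r1",
    #    "azure/deepseek-R1", "azure/deepseek-r1"}
    # and litellm.cost_per_token is tried against each until one is mapped.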

def get_cost_estimate(
self,
prompt: str,
@@ -522,23 +468,9 @@ def get_cost_estimate(
cumulative_prompt_cost = 0
cumulative_completion_cost = 0
cumulative_total_cost = 0
# for azure, we need the canonical model name, not the deployment name
if self.model.startswith("azure/"):
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = litellm.completion(
messages=messages,
model=self.model,
max_tokens=10,
temperature=self.temperature,
seed=self.seed,
# extra inference params passed to this runtime
**self.model_extra,
)
model = "azure/" + response.model
else:
model = self.model
model = self.canonical_model_provider_string
for user_prompt in user_prompts:
prompt_cost, completion_cost, total_cost = self._estimate_cost(
prompt_cost, completion_cost, total_cost = _estimate_cost(
user_prompt=user_prompt,
model=model,
output_fields=output_fields,
@@ -622,29 +554,3 @@ async def batch_to_batch(

output_df = InternalDataFrame(df_data)
return output_df.set_index(batch.index)

# TODO: cost estimate


def get_model_info(
provider: str, model_name: str, auth_info: Optional[dict] = None
) -> dict:
if auth_info is None:
auth_info = {}
try:
# for azure models, need to get the canonical name for the model
if provider == "azure":
dummy_completion = litellm.completion(
model=f"azure/{model_name}",
messages=[{"role": "user", "content": ""}],
max_tokens=1,
**auth_info,
)
model_name = dummy_completion.model
model_name, provider = normalize_litellm_model_and_provider(
model_name, provider
)
return litellm.get_model_info(model=model_name, custom_llm_provider=provider)
except Exception as err:
logger.error("Hit error when trying to get model metadata: %s", err)
return {}
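
For context, the removed helper was a thin wrapper over litellm.get_model_info; a sketch of the assumed call path (the return keys shown are typical litellm metadata fields, not verified output):

    # Illustrative only:
    # get_model_info("openai", "gpt-4o")
    #   -> litellm.get_model_info(model="gpt-4o", custom_llm_provider="openai")
    #   -> {"max_tokens": ..., "input_cost_per_token": ..., "output_cost_per_token": ..., ...}
    # and {} if anything raises along the way.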
4 changes: 1 addition & 3 deletions adala/utils/llm_utils.py
@@ -2,10 +2,8 @@
import logging
import traceback
import litellm
from litellm import token_counter
from collections import defaultdict
from typing import Any, Dict, List, Type, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
from pydantic_core import to_jsonable_python
from litellm.types.utils import Usage
from tenacity import Retrying, AsyncRetrying