diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index 6b029db2..e88f0328 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -46,6 +46,11 @@
     run_instructor_with_messages,
     arun_instructor_with_messages,
 )
+from adala.utils.model_info_utils import (
+    match_model_provider_string,
+    NoModelsFoundError,
+    _estimate_cost,
+)
 from pydantic import ConfigDict, field_validator, BaseModel
 from pydantic_core import to_jsonable_python
@@ -85,27 +90,6 @@
 async_retries = AsyncRetrying(**RETRY_POLICY)
 
 
-def normalize_litellm_model_and_provider(model_name: str, provider: str):
-    """
-    When using litellm.get_model_info() some models are accessed with their provider prefix
-    while others are not.
-
-    This helper function contains logic which normalizes this for supported providers
-    """
-    if "/" in model_name:
-        model_name = model_name.split("/", 1)[1]
-    provider = provider.lower()
-    # TODO: move this logic to LSE, this is the last place Adala needs to be updated when adding a provider connection
-    if provider == "vertexai":
-        provider = "vertex_ai"
-    if provider == "azureopenai":
-        provider = "azure"
-    if provider == "azureaifoundry":
-        provider = "azure_ai"
-
-    return model_name, provider
-
-
 class InstructorClientMixin(BaseModel):
 
     # Note: most models work better with json mode; this is set only for backwards compatibility
@@ -124,18 +108,19 @@
     def _openai_client(self):
         return OpenAI
 
     def _check_client(self):
-        run_instructor_with_messages(
-            client=self.client,
+        # don't use response model and error handling from run_instructor_with_messages here
+        response = self.client.chat.completions.create(
             messages=[{"role": "user", "content": "Hey, how's it going?"}],
             model=self.model,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
             seed=self.seed,
             response_model=None,
-            retries=retries,
+            max_retries=retries,
             # extra inference params passed to this runtime
             **self.model_extra,
         )
+        return response
 
     # Yes, this is recomputed every time - there's no clean way to cache it unless we drop pickle serialization, in which case adding it to ConfigDict(ignore) would work.
     # There's no appreciable startup cost in the instructor client init function anyway.
@@ -176,6 +161,33 @@
 
         return self
 
+    def get_canonical_model_provider_string(self, model: str) -> str:
+        """provider_name/model_name"""
+        # this is really a litellm function, not an instructor function. Putting it here to avoid duplicating it between sync/async runtimes.
+        try:
+            return match_model_provider_string(model)
+        except NoModelsFoundError:
+            logger.info(
+                f"Model {model} not found in litellm model map for provider {self.provider}. This is likely a single-model deployment."
+            )
+        except Exception as e:
+            logger.exception(
+                f"(1/2) Failed to get canonical model provider string for {model}"
+            )
+        try:
+            resp = self._check_client()
+            return match_model_provider_string(resp.model)
+        except NoModelsFoundError:
+            logger.warning(
+                f"Model {model} not found in litellm model map for provider {self.provider}. This is likely a custom model."
+            )
+            return model
+        except Exception as e:
+            logger.exception(
+                f"(2/2) Failed to get canonical model provider string for {model}"
+            )
+            return model
+
 
 class InstructorAsyncClientMixin(InstructorClientMixin):
@@ -190,18 +202,19 @@ def _openai_client(self):
     def _check_client(self):
         """Make this synchronous"""
         client = InstructorClientMixin(**self.model_dump()).client
-        run_instructor_with_messages(
-            client=client,
+        # don't use response model and error handling from run_instructor_with_messages here
+        response = client.chat.completions.create(
             messages=[{"role": "user", "content": "Hey, how's it going?"}],
             model=self.model,
             max_tokens=self.max_tokens,
             temperature=self.temperature,
             seed=self.seed,
             response_model=None,
-            retries=retries,
+            max_retries=retries,
             # extra inference params passed to this runtime
             **self.model_extra,
         )
+        return response
 
 
 class LiteLLMChatRuntime(InstructorClientMixin, Runtime):
@@ -232,6 +245,11 @@ class LiteLLMChatRuntime(InstructorClientMixin, Runtime):
 
     model_config = ConfigDict(extra="allow")
 
+    @property
+    def canonical_model_provider_string(self):
+        """provider_name/model_name"""
+        return self.get_canonical_model_provider_string(self.model)
+
     def get_llm_response(self, messages: List[Dict[str, str]]) -> str:
         # TODO: sunset this method in favor of record_to_record
         if self.verbose:
@@ -342,6 +360,11 @@ def check_concurrency(cls, value) -> int:
             )
         return value
 
+    @property
+    def canonical_model_provider_string(self):
+        """provider_name/model_name"""
+        return self.get_canonical_model_provider_string(self.model)
+
     async def batch_to_batch(
         self,
         batch: InternalDataFrame,
@@ -430,83 +453,6 @@ async def record_to_record(
 
         # Extract the single row from the output DataFrame and convert it to a dictionary
         return output_df.iloc[0].to_dict()
 
-    @staticmethod
-    def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
-        user_tokens = litellm.token_counter(model=model, text=string)
-        # FIXME surprisingly difficult to get function call tokens, and doesn't add a ton of value, so hard-coding until something like litellm supports doing this for us.
-        # currently seems like we'd need to scrape the instructor logs to get the function call info, then use (at best) an openai-specific 3rd party lib to get a token estimate from that.
-        system_tokens = 56 + (6 * len(output_fields))
-        return user_tokens + system_tokens
-
-    @staticmethod
-    def _get_completion_tokens(
-        candidate_model_names: List[str],
-        output_fields: Optional[List[str]],
-        provider: str,
-    ) -> int:
-        max_tokens = None
-        for model in candidate_model_names:
-            try:
-                max_tokens = litellm.get_model_info(model=model).get("max_tokens", None)
-                break
-            except Exception as e:
-                if "model isn't mapped" in str(e):
-                    continue
-                else:
-                    raise e
-        if not max_tokens:
-            raise ValueError
-        # extremely rough heuristic, from testing on some anecdotal examples
-        n_outputs = len(output_fields) if output_fields else 1
-        return min(max_tokens, 4 * n_outputs)
-
-    @classmethod
-    def _estimate_cost(
-        cls,
-        user_prompt: str,
-        model: str,
-        output_fields: Optional[List[str]],
-        provider: str,
-    ):
-        prompt_tokens = cls._get_prompt_tokens(user_prompt, model, output_fields)
-        # amazingly, litellm.cost_per_token refers to a hardcoded dictionary litellm.model_cost which is case-sensitive with inconsistent casing.....
-        # Example: 'azure_ai/deepseek-r1' vs 'azure_ai/Llama-3.3-70B-Instruct'
-        # so we have no way of determining the correct casing or reliably fixing it.
-        # we can at least try all-lowercase.
-        candidate_model_names = [model, model.lower()]
-        # ...and Azure AI Foundry openai models are not listed there, but under Azure OpenAI
-        if model.startswith("azure_ai/"):
-            candidate_model_names.append(model.replace("azure_ai/", "azure/"))
-            candidate_model_names.append(model.replace("azure_ai/", "azure/").lower())
-        candidate_model_names = list(set(candidate_model_names))
-
-        completion_tokens = cls._get_completion_tokens(
-            candidate_model_names, output_fields, provider
-        )
-
-        prompt_cost, completion_cost = None, None
-        for candidate_model_name in candidate_model_names:
-            try:
-                prompt_cost, completion_cost = litellm.cost_per_token(
-                    model=candidate_model_name,
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                )
-            except Exception as e:
-                # it also doesn't have a type to catch:
-                # Exception("This model isn't mapped yet. model=azure_ai/deepseek-R1, custom_llm_provider=azure_ai. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.")
-                if "model isn't mapped" in str(e):
-                    pass
-            if prompt_cost is not None and completion_cost is not None:
-                break
-
-        if prompt_cost is None or completion_cost is None:
-            raise ValueError(f"Model {model} for provider {provider} not found.")
-
-        total_cost = prompt_cost + completion_cost
-
-        return prompt_cost, completion_cost, total_cost
-
     def get_cost_estimate(
         self,
         prompt: str,
@@ -522,23 +468,9 @@
         cumulative_prompt_cost = 0
         cumulative_completion_cost = 0
         cumulative_total_cost = 0
-        # for azure, we need the canonical model name, not the deployment name
-        if self.model.startswith("azure/"):
-            messages = [{"role": "user", "content": "Hey, how's it going?"}]
-            response = litellm.completion(
-                messages=messages,
-                model=self.model,
-                max_tokens=10,
-                temperature=self.temperature,
-                seed=self.seed,
-                # extra inference params passed to this runtime
-                **self.model_extra,
-            )
-            model = "azure/" + response.model
-        else:
-            model = self.model
+        model = self.canonical_model_provider_string
         for user_prompt in user_prompts:
-            prompt_cost, completion_cost, total_cost = self._estimate_cost(
+            prompt_cost, completion_cost, total_cost = _estimate_cost(
                 user_prompt=user_prompt,
                 model=model,
                 output_fields=output_fields,
@@ -622,29 +554,3 @@ async def batch_to_batch(
             output_df = InternalDataFrame(df_data)
 
         return output_df.set_index(batch.index)
-
-    # TODO: cost estimate
-
-
-def get_model_info(
-    provider: str, model_name: str, auth_info: Optional[dict] = None
-) -> dict:
-    if auth_info is None:
-        auth_info = {}
-    try:
-        # for azure models, need to get the canonical name for the model
-        if provider == "azure":
-            dummy_completion = litellm.completion(
-                model=f"azure/{model_name}",
-                messages=[{"role": "user", "content": ""}],
-                max_tokens=1,
-                **auth_info,
-            )
-            model_name = dummy_completion.model
-        model_name, provider = normalize_litellm_model_and_provider(
-            model_name, provider
-        )
-        return litellm.get_model_info(model=model_name, custom_llm_provider=provider)
-    except Exception as err:
-        logger.error("Hit error when trying to get model metadata: %s", err)
-        return {}
diff --git a/adala/utils/llm_utils.py b/adala/utils/llm_utils.py
index 5b4bfd69..cb27d3e2 100644
--- a/adala/utils/llm_utils.py
+++ b/adala/utils/llm_utils.py
@@ -2,10 +2,8 @@
 import logging
 import traceback
 import litellm
-from litellm import token_counter
-from collections import defaultdict
 from typing import Any, Dict, List, Type, Optional
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from pydantic_core import to_jsonable_python
 from litellm.types.utils import Usage
 from tenacity import Retrying, AsyncRetrying
diff --git a/adala/utils/model_info_utils.py b/adala/utils/model_info_utils.py
new file mode 100644
index 00000000..910adc3f
--- /dev/null
+++ b/adala/utils/model_info_utils.py
@@ -0,0 +1,143 @@
+import re
+import litellm
+import logging
+from typing import Optional, List
+
+logger = logging.getLogger(__name__)
+
+
+def normalize_litellm_model_and_provider(model_name: str, provider: str):
+    """
+    When using litellm.get_model_info() some models are accessed with their provider prefix
+    while others are not.
+
+    This helper function contains logic which normalizes this for supported providers
+    """
+    if "/" in model_name:
+        model_name = model_name.split("/", 1)[1]
+    provider = provider.lower()
+    # TODO: move this logic to LSE, this is the last place Adala needs to be updated when adding a provider connection
+    if provider == "vertexai":
+        provider = "vertex_ai"
+    if provider == "azureopenai":
+        provider = "azure"
+    if provider == "azureaifoundry":
+        provider = "azure_ai"
+
+    return model_name, provider
+
+
+def normalize_canonical_model_name(model: str) -> str:
+    """Strip date suffix from model name if present at the end (e.g. gpt-4-0613 -> gpt-4)"""
+    # We only know that this works for models hosted on azure openai, azure foundry, and openai
+    # 'gpt-4-2024-04-01'
+    if re.search(r"-\d{4}-\d{2}-\d{2}$", model):
+        model = re.sub(r"-\d{4}-\d{2}-\d{2}$", "", model)
+    # 'gpt-4-0613'
+    elif re.search(r"-\d{4}$", model):
+        model = re.sub(r"-\d{4}$", "", model)
+    return model
+
+
+class NoModelsFoundError(ValueError):
+    """Raised when a model cannot be found in litellm's model map"""
+
+    pass
+
+
+def match_model_provider_string(model: str) -> str:
+    """Given a string of the form 'provider/model', return the 'provider/model' as listed in litellm's model map"""
+    # NOTE: if needed, can pass api_base and api_key into this function for additional hints
+    model_name, provider, _, _ = litellm.get_llm_provider(model)
+
+    # amazingly, litellm.cost_per_token refers to a hardcoded dictionary litellm.model_cost which is case-sensitive with inconsistent casing.....
+    # Example: 'azure_ai/deepseek-r1' vs 'azure_ai/Llama-3.3-70B-Instruct'
+    lowercase_to_canonical_case = {
+        k.lower(): k for k in litellm.models_by_provider[provider]
+    }
+    candidate_model_names = []
+    for name in [model_name, normalize_canonical_model_name(model_name)]:
+        candidate_model_names.append("/".join([provider, name.lower()]))
+    # ...and Azure AI Foundry openai models are not listed there, but under Azure OpenAI
+    if provider == "azure_ai":
+        # iterate over a snapshot so the appends below don't extend the loop
+        for candidate in list(candidate_model_names):
+            candidate_model_names.append(candidate.replace("azure_ai/", "azure/"))
+    matched_models = set(candidate_model_names) & set(lowercase_to_canonical_case)
+    if len(matched_models) == 0:
+        raise NoModelsFoundError(model)
+    if len(matched_models) > 1:
+        logger.warning(f"Multiple models found for {model}: {matched_models}")
+    return lowercase_to_canonical_case[matched_models.pop()]
+
+
+def get_model_info(
+    provider: str, model_name: str, auth_info: Optional[dict] = None
+) -> dict:
+    if auth_info is None:
+        auth_info = {}
+    try:
+        # for azure models, need to get the canonical name for the model
+        if provider == "azure":
+            dummy_completion = litellm.completion(
+                model=f"azure/{model_name}",
+                messages=[{"role": "user", "content": ""}],
+                max_tokens=1,
+                **auth_info,
+            )
+            model_name = dummy_completion.model
+        model_name, provider = normalize_litellm_model_and_provider(
+            model_name, provider
+        )
+        return litellm.get_model_info(model=model_name, custom_llm_provider=provider)
+    except Exception as err:
+        logger.error("Hit error when trying to get model metadata: %s", err)
+        return {}
+
+
+def _get_prompt_tokens(string: str, model: str, output_fields: List[str]) -> int:
+    user_tokens = litellm.token_counter(model=model, text=string)
+    # FIXME surprisingly difficult to get function call tokens, and doesn't add a ton of value, so hard-coding until something like litellm supports doing this for us.
+    # currently seems like we'd need to scrape the instructor logs to get the function call info, then use (at best) an openai-specific 3rd party lib to get a token estimate from that.
+    system_tokens = 56 + (6 * len(output_fields))
+    return user_tokens + system_tokens
+
+
+def _get_completion_tokens(
+    model: str,
+    output_fields: Optional[List[str]],
+) -> int:
+    max_tokens = litellm.get_model_info(model=model).get("max_tokens", None)
+    if not max_tokens:
+        raise ValueError(f"Model {model} has no max tokens.")
+    # extremely rough heuristic, from testing on some anecdotal examples
+    n_outputs = len(output_fields) if output_fields else 1
+    return min(max_tokens, 4 * n_outputs)
+
+
+def _estimate_cost(
+    user_prompt: str,
+    model: str,
+    output_fields: Optional[List[str]],
+    provider: str,
+):
+    try:
+        prompt_tokens = _get_prompt_tokens(user_prompt, model, output_fields)
+
+        completion_tokens = _get_completion_tokens(model, output_fields)
+
+        prompt_cost, completion_cost = litellm.cost_per_token(
+            model=model,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+    except Exception as e:
+        # Missing model exception doesn't have a type to catch:
+        # Exception("This model isn't mapped yet. model=azure_ai/deepseek-R1, custom_llm_provider=azure_ai. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.")
+        if "model isn't mapped" in str(e):
+            raise ValueError(f"Model {model} for provider {provider} not found.")
+        else:
+            raise e
+
+    total_cost = prompt_cost + completion_cost
+
+    return prompt_cost, completion_cost, total_cost
diff --git a/server/app.py b/server/app.py
index 1162a777..91ffc6a2 100644
--- a/server/app.py
+++ b/server/app.py
@@ -571,7 +571,7 @@ class ModelMetadataResponse(BaseModel):
 
 @app.post("/model-metadata", response_model=Response[ModelMetadataResponse])
 async def model_metadata(request: ModelMetadataRequest):
-    from adala.runtimes._litellm import get_model_info
+    from adala.utils.model_info_utils import get_model_info
 
     resp = {
         "model_metadata": {
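
For orientation, a minimal usage sketch (not part of the diff) of the helpers this change consolidates into adala.utils.model_info_utils; the model string, provider, and output field names below are hypothetical examples.

    from adala.utils.model_info_utils import (
        NoModelsFoundError,
        _estimate_cost,
        match_model_provider_string,
    )

    raw_model = "openai/gpt-4o-2024-08-06"  # hypothetical example model string

    try:
        # Resolve to the canonical "provider/model" key used by litellm's model map.
        model = match_model_provider_string(raw_model)
    except NoModelsFoundError:
        # Custom or single-model deployments may not appear in the map; fall back to the raw name.
        model = raw_model

    # Rough cost estimate for one prompt; raises ValueError if litellm has no pricing for the model.
    prompt_cost, completion_cost, total_cost = _estimate_cost(
        user_prompt="Classify the sentiment of this review: 'great product!'",
        model=model,
        output_fields=["sentiment"],
        provider="openai",
    )
    print(f"estimated cost: {total_cost:.6f} USD")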