fix: DIA-2053: Support vision by default in Azure models #365

Closed
wants to merge 2 commits into from
5 changes: 3 additions & 2 deletions adala/runtimes/_litellm.py
@@ -24,7 +24,7 @@
     BadRequestError,
     NotFoundError,
     APIConnectionError,
-    APIError
+    APIError,
 )
 from litellm.types.utils import Usage
 import instructor
@@ -518,7 +518,8 @@ def init_runtime(self) -> "Runtime":
         super().init_runtime()
         # Only running this supports_vision check for non-vertex models, since its based on a static JSON file in
         # litellm which was not up to date. Will be soon in next release - should update this
-        if not self.model.startswith("vertex_ai"):
+        # Added azure to the exception list bc https://github.com/BerriAI/litellm/issues/6524
+        if not self.model.startswith(("vertex_ai", "azure")):
             model_name = self.model
             if not litellm.supports_vision(model_name):
                 raise ValueError(f"Model {self.model} does not support vision")
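The gate above skips litellm's `supports_vision()` lookup for providers whose entries in litellm's static capability table can lag behind reality (vertex_ai, and now azure). A minimal standalone sketch of that logic, with illustrative names that are not the PR's own helpers:

```python
# Providers exempted from the static capability check because the
# underlying JSON table may be stale for them (assumption for this sketch).
SKIP_PREFIXES = ("vertex_ai", "azure")

def check_vision(model: str, supports_vision) -> None:
    # str.startswith accepts a tuple and matches if any prefix matches.
    if model.startswith(SKIP_PREFIXES):
        return  # trust the provider rather than a possibly stale table
    if not supports_vision(model):
        raise ValueError(f"Model {model} does not support vision")
```

Passing the prefixes as a single tuple to `str.startswith` keeps the exception list in one place, so adding another lagging provider later is a one-token change.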
6 changes: 5 additions & 1 deletion adala/utils/llm_utils.py
@@ -103,7 +103,11 @@ def handle_llm_exception(
     n_attempts = retries.stop.max_attempt_number
     if prompt_token_count is None:
         prompt_token_count = token_counter(model=model, messages=messages[:-1])
-    if type(e).__name__ in {"APIError", "AuthenticationError", "APIConnectionError"}:
+    if type(e).__name__ in {
+        "APIError",
+        "AuthenticationError",
+        "APIConnectionError",
+    }:
         prompt_tokens = 0
     else:
         prompt_tokens = n_attempts * prompt_token_count
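The branch above implements a token-accounting rule: exceptions that fire before the prompt reaches the model (auth, connection, generic API errors) should bill zero prompt tokens, while any other failure is assumed to have sent the prompt once per retry attempt. A hedged sketch of that rule in isolation (the helper name is hypothetical, not from the PR):

```python
# Exception class names treated as "prompt never reached the model".
NO_PROMPT_SENT = {"APIError", "AuthenticationError", "APIConnectionError"}

def billed_prompt_tokens(e: Exception, n_attempts: int, prompt_token_count: int) -> int:
    # Compare by class name, mirroring type(e).__name__ in the diff; this
    # avoids importing provider-specific exception types just to isinstance them.
    if type(e).__name__ in NO_PROMPT_SENT:
        return 0
    return n_attempts * prompt_token_count
```

Matching on `type(e).__name__` rather than `isinstance` trades strictness for decoupling: the accounting code does not need to import litellm's exception hierarchy, at the cost of matching any exception that happens to share a name.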