-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathazure.py
More file actions
146 lines (123 loc) · 5.97 KB
/
azure.py
File metadata and controls
146 lines (123 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from __future__ import annotations as _annotations
import os
from typing import overload
import httpx
from openai import AsyncOpenAI
from pydantic_ai import ModelProfile
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.providers import Provider
# `AsyncAzureOpenAI` comes from the optional `openai` dependency; fail fast at import
# time with an actionable install hint instead of a bare ImportError at first use.
try:
    from openai import AsyncAzureOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Azure provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error
class AzureProvider(Provider[AsyncOpenAI]):
    """Provider for Azure OpenAI API.

    See <https://azure.microsoft.com/en-us/products/ai-foundry> for more information.
    """

    @property
    def name(self) -> str:
        return 'azure'

    @property
    def base_url(self) -> str:
        # Always set in `__init__` (from either the supplied or the constructed client),
        # so it can never be None here.
        assert self._base_url is not None
        return self._base_url

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Return the profile for `model_name`, wrapped with Azure-specific OpenAI settings.

        Non-OpenAI models deployed on Azure are recognized by a vendor prefix; OpenAI
        models are unprefixed and handled by the fallback at the end.
        """
        model_name = model_name.lower()

        # Insertion order matters: longer/more specific prefixes (e.g. 'mistralai-')
        # must be tried before their shorter counterparts (e.g. 'mistral').
        prefix_to_profile = {
            'llama': meta_model_profile,
            'meta-': meta_model_profile,
            'deepseek': deepseek_model_profile,
            'mistralai-': mistral_model_profile,
            'mistral': mistral_model_profile,
            'cohere-': cohere_model_profile,
            'grok': grok_model_profile,
        }

        for prefix, profile_func in prefix_to_profile.items():
            if model_name.startswith(prefix):
                if prefix.endswith('-'):
                    # Vendor prefixes ending in '-' are stripped before looking up the profile.
                    model_name = model_name[len(prefix) :]

                profile = profile_func(model_name)

                # As AzureProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
                # we need to maintain that behavior unless json_schema_transformer is set explicitly
                # Azure Chat Completions API doesn't support file input
                return OpenAIModelProfile(
                    json_schema_transformer=OpenAIJsonSchemaTransformer,
                    openai_chat_supports_file_input=False,
                ).update(profile)

        # OpenAI models are unprefixed
        # Azure Chat Completions API doesn't support file input
        return OpenAIModelProfile(openai_chat_supports_file_input=False).update(openai_model_profile(model_name))

    @overload
    def __init__(self, *, openai_client: AsyncAzureOpenAI) -> None: ...

    @overload
    def __init__(
        self,
        *,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None: ...

    def __init__(
        self,
        *,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        api_key: str | None = None,
        openai_client: AsyncAzureOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Azure provider.

        Args:
            azure_endpoint: The Azure endpoint to use for authentication, if not provided, the `AZURE_OPENAI_ENDPOINT`
                environment variable will be used if available.
            api_version: The API version to use for authentication, if not provided, the `OPENAI_API_VERSION`
                environment variable will be used if available.
            api_key: The API key to use for authentication, if not provided, the `AZURE_OPENAI_API_KEY` environment variable
                will be used if available.
            openai_client: An existing
                [`AsyncAzureOpenAI`](https://github.com/openai/openai-python#microsoft-azure-openai)
                client to use. If provided, `azure_endpoint`, `api_version`, `api_key`, and `http_client`
                must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.

        Raises:
            UserError: If the endpoint, API key, or API version is supplied neither as an argument
                nor via its corresponding environment variable.
        """
        if openai_client is not None:
            # A pre-built client already carries endpoint/version/key configuration, so the
            # individual settings must not be passed alongside it (mirrors the first overload).
            # The `api_version` check was previously missing, letting a contradictory
            # combination pass silently.
            assert azure_endpoint is None, 'Cannot provide both `openai_client` and `azure_endpoint`'
            assert api_version is None, 'Cannot provide both `openai_client` and `api_version`'
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self._base_url = str(openai_client.base_url)
            self._client = openai_client
        else:
            azure_endpoint = azure_endpoint or os.getenv('AZURE_OPENAI_ENDPOINT')
            if not azure_endpoint:
                raise UserError(
                    'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable'
                )
            if not api_key and 'AZURE_OPENAI_API_KEY' not in os.environ:  # pragma: no cover
                raise UserError(
                    'Must provide one of the `api_key` argument or the `AZURE_OPENAI_API_KEY` environment variable'
                )
            if not api_version and 'OPENAI_API_VERSION' not in os.environ:  # pragma: no cover
                raise UserError(
                    'Must provide one of the `api_version` argument or the `OPENAI_API_VERSION` environment variable'
                )

            http_client = http_client or cached_async_http_client(provider='azure')
            # `AsyncAzureOpenAI` falls back to the environment variables itself when
            # `api_key`/`api_version` are None — the checks above only guarantee a value exists.
            self._client = AsyncAzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=api_key,
                api_version=api_version,
                http_client=http_client,
            )
            self._base_url = str(self._client.base_url)