-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathtest_azure.py
More file actions
164 lines (132 loc) · 6.54 KB
/
test_azure.py
File metadata and controls
164 lines (132 loc) · 6.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import pytest
from inline_snapshot import snapshot
from pytest_mock import MockerFixture
from pydantic_ai import BinaryContent
from pydantic_ai._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.agent import Agent
from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile
from ..conftest import try_import
with try_import() as imports_successful:
from openai import AsyncAzureOpenAI
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.azure import AzureProvider
# Skip the whole module when the optional ``openai`` dependency (guarded by
# ``try_import`` above) could not be imported.
_SKIP_WITHOUT_OPENAI = pytest.mark.skipif(not imports_successful(), reason='openai not installed')

# Marks applied to every test in this module:
#   * skip when openai is unavailable,
#   * ``vcr`` — HTTP record/replay (presumably via a pytest VCR plugin; confirm in conftest),
#   * ``anyio`` — lets the ``async def`` tests below run under the anyio pytest plugin.
pytestmark = [_SKIP_WITHOUT_OPENAI, pytest.mark.vcr, pytest.mark.anyio]
def test_azure_provider():
    """An AzureProvider built from endpoint, API version and key is named 'azure' and wraps an AsyncAzureOpenAI."""
    azure_provider = AzureProvider(
        api_key='1234567890',
        api_version='2023-03-15-preview',
        azure_endpoint='https://project-id.openai.azure.com/',
    )

    assert isinstance(azure_provider, AzureProvider)
    assert azure_provider.name == 'azure'
    # The resulting base URL is the configured endpoint with an ``openai/`` path segment appended.
    assert azure_provider.base_url == snapshot('https://project-id.openai.azure.com/openai/')
    assert isinstance(azure_provider.client, AsyncAzureOpenAI)
def test_azure_provider_with_openai_model():
    """An OpenAIChatModel configured with an AzureProvider exposes an AsyncAzureOpenAI client."""
    azure_provider = AzureProvider(
        azure_endpoint='https://project-id.openai.azure.com/',
        api_version='2023-03-15-preview',
        api_key='1234567890',
    )
    chat_model = OpenAIChatModel(model_name='gpt-4o', provider=azure_provider)

    assert isinstance(chat_model, OpenAIChatModel)
    assert isinstance(chat_model.client, AsyncAzureOpenAI)
def test_azure_provider_with_azure_openai_client():
    """An AzureProvider can be constructed from an already-built AsyncAzureOpenAI client."""
    prebuilt_client = AsyncAzureOpenAI(
        api_version='2024-12-01-preview',
        azure_endpoint='https://project-id.openai.azure.com/',
        api_key='1234567890',
    )

    wrapped = AzureProvider(openai_client=prebuilt_client)
    assert isinstance(wrapped.client, AsyncAzureOpenAI)
async def test_azure_provider_call(allow_model_requests: None):
    """Run an Agent end-to-end against an Azure OpenAI ``gpt-4o`` deployment.

    The API key and version are read from ``AZURE_OPENAI_API_KEY`` /
    ``AZURE_OPENAI_API_VERSION`` when set; otherwise placeholder values are used
    (presumably so the request can be replayed from a recorded cassette — confirm).
    """
    azure_provider = AzureProvider(
        api_key=os.getenv('AZURE_OPENAI_API_KEY', '1234567890'),
        azure_endpoint='https://pydanticai7521574644.openai.azure.com/',
        api_version=os.getenv('AZURE_OPENAI_API_VERSION', '2024-12-01-preview'),
    )
    agent = Agent(OpenAIChatModel(model_name='gpt-4o', provider=azure_provider))

    run_result = await agent.run('What is the capital of France?')
    assert run_result.output == snapshot('The capital of France is **Paris**.')
def test_azure_provider_model_profile(mocker: MockerFixture):
    """AzureProvider.model_profile routes each model name to the matching vendor profile function.

    Every profile function is patched with a spy that still calls the real
    implementation (``wraps=``). For each case the test checks three things: which
    spy was called last and with what normalized name, that a profile was returned,
    and which JSON-schema transformer that profile carries.
    """
    provider = AzureProvider(
        azure_endpoint='https://project-id.openai.azure.com/',
        api_version='2023-03-15-preview',
        api_key='1234567890',
    )

    ns = 'pydantic_ai.providers.azure'
    # Patch in the same order as the individual profile functions are listed.
    spies = {}
    for attr_name, real_fn in (
        ('meta_model_profile', meta_model_profile),
        ('deepseek_model_profile', deepseek_model_profile),
        ('mistral_model_profile', mistral_model_profile),
        ('cohere_model_profile', cohere_model_profile),
        ('grok_model_profile', grok_model_profile),
        ('openai_model_profile', openai_model_profile),
    ):
        spies[attr_name] = mocker.patch(f'{ns}.{attr_name}', wraps=real_fn)

    # (model name passed to the provider, spy expected to be hit,
    #  argument that spy should receive, expected json_schema_transformer)
    cases = [
        ('Llama-4-Scout-17B-16E', 'meta_model_profile', 'llama-4-scout-17b-16e', InlineDefsJsonSchemaTransformer),
        ('Meta-Llama-3.1-405B-Instruct', 'meta_model_profile', 'llama-3.1-405b-instruct', InlineDefsJsonSchemaTransformer),
        ('DeepSeek-R1', 'deepseek_model_profile', 'deepseek-r1', OpenAIJsonSchemaTransformer),
        ('mistral-medium-2505', 'mistral_model_profile', 'mistral-medium-2505', OpenAIJsonSchemaTransformer),
        (
            'mistralai-Mixtral-8x22B-Instruct-v0-1',
            'mistral_model_profile',
            'mixtral-8x22b-instruct-v0-1',
            OpenAIJsonSchemaTransformer,
        ),
        ('cohere-command-a', 'cohere_model_profile', 'command-a', OpenAIJsonSchemaTransformer),
        ('grok-3', 'grok_model_profile', 'grok-3', OpenAIJsonSchemaTransformer),
        ('o4-mini', 'openai_model_profile', 'o4-mini', OpenAIJsonSchemaTransformer),
        # Names matching no known vendor fall through to the OpenAI profile.
        ('unknown-model', 'openai_model_profile', 'unknown-model', OpenAIJsonSchemaTransformer),
    ]

    # Cases run sequentially; ``assert_called_with`` checks the most recent call of each spy.
    for requested_name, spy_name, normalized_name, expected_transformer in cases:
        profile = provider.model_profile(requested_name)
        spies[spy_name].assert_called_with(normalized_name)
        assert profile is not None
        assert profile.json_schema_transformer == expected_transformer
async def test_azure_document_input_not_supported(allow_model_requests: None):
    """Sending a PDF via Azure's Chat Completions API raises a UserError that mentions OpenAIResponsesModel."""
    azure_model = OpenAIChatModel(
        model_name='gpt-4o',
        provider=AzureProvider(
            azure_endpoint='https://project-id.openai.azure.com/',
            api_version='2023-03-15-preview',
            api_key='1234567890',
        ),
    )
    agent = Agent(azure_model)

    expected_message = "Azure's Chat Completions API does not support document input.*OpenAIResponsesModel"
    with pytest.raises(UserError, match=expected_message):
        prompt = [
            'Summarize this document',
            BinaryContent(data=b'%PDF-1.4 test', media_type='application/pdf'),
        ]
        await agent.run(prompt)