Skip to content

Commit 84aef77

Browse files
authored
Merge pull request #385 from howardbaik/main
Add AzureOpenAI as a model provider
2 parents 230f130 + 2482c3e commit 84aef77

6 files changed

Lines changed: 94 additions & 5 deletions

File tree

pointblank/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
"anthropic",
142142
"ollama",
143143
"bedrock",
144+
"azure-openai",
144145
]
145146

146147
TABLE_TYPE_STYLES = {

pointblank/_utils_ai.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,41 @@ def _create_chat_instance(
172172
kwargs={"http_client": http_client},
173173
)
174174

175+
elif provider == "azure-openai": # pragma: no cover
176+
try:
177+
import openai # noqa
178+
except ImportError:
179+
raise ImportError(
180+
"The `openai` package is required to use AI validation with "
181+
"`azure-openai`. Please install it using `pip install openai`."
182+
)
183+
184+
import os
185+
186+
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
187+
api_version = os.getenv("OPENAI_API_VERSION")
188+
if not endpoint:
189+
raise ValueError(
190+
"AZURE_OPENAI_ENDPOINT environment variable must be set to use "
191+
"the 'azure-openai' provider."
192+
)
193+
if not api_version:
194+
raise ValueError(
195+
"OPENAI_API_VERSION environment variable must be set to use "
196+
"the 'azure-openai' provider (e.g. '2024-06-01')."
197+
)
198+
199+
from chatlas import ChatAzureOpenAI
200+
201+
chat = ChatAzureOpenAI(
202+
endpoint=endpoint,
203+
deployment_id=model_name,
204+
api_version=api_version,
205+
api_key=api_key,
206+
system_prompt=system_prompt,
207+
kwargs={"http_client": http_client},
208+
)
209+
175210
else:
176211
raise ValueError(f"Unsupported provider: {provider}")
177212

pointblank/validate.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10929,9 +10929,10 @@ def prompt(
1092910929
model
1093010930
The model to be used. This should be in the form of `provider:model` (e.g.,
1093110931
`"anthropic:claude-opus-4-6"`). Supported providers are `"anthropic"`, `"openai"`,
10932-
`"ollama"`, and `"bedrock"`. The model name should be the specific model to be used from
10933-
the provider. Model names are subject to change so consult the provider's documentation
10934-
for the most up-to-date model names.
10932+
`"ollama"`, `"bedrock"`, and `"azure-openai"`. The model name should be the specific
10933+
model to be used from the provider (for `"azure-openai"`, the value after the colon is
10934+
the Azure *deployment id*). Model names are subject to change so consult the provider's
10935+
documentation for the most up-to-date model names.
1093510936
batch_size
1093610937
Number of rows to process in each batch. Larger batches are more efficient but may hit
1093710938
API limits. Default is `1000`.
@@ -10985,10 +10986,13 @@ def prompt(
1098510986
- `"openai"` (OpenAI)
1098610987
- `"ollama"` (Ollama)
1098710988
- `"bedrock"` (Amazon Bedrock)
10989+
- `"azure-openai"` (Azure OpenAI)
1098810990

1098910991
The model name should be the specific model to be used from the provider. Model names are
1099010992
subject to change so consult the provider's documentation for the most up-to-date model
10991-
names.
10993+
names. For `"azure-openai"`, the value after the colon is the Azure *deployment id* (the
10994+
name you assigned when deploying the model in your Azure OpenAI resource), not an OpenAI
10995+
model id.
1099210996

1099310997
Notes on Authentication
1099410998
-----------------------
@@ -11019,6 +11023,8 @@ def prompt(
1101911023
- **Anthropic**: set `ANTHROPIC_API_KEY` environment variable or create `.env` file
1102011024
- **Ollama**: no API key required, just ensure Ollama is running locally
1102111025
- **Bedrock**: configure AWS credentials through standard AWS methods
11026+
- **Azure OpenAI**: set `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT` (e.g.,
11027+
`https://<resource>.openai.azure.com`), and `OPENAI_API_VERSION` (e.g., `2024-06-01`)
1102211028

1102311029
AI Validation Process
1102411030
---------------------

tests/test__utils_ai.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,46 @@ def test_create_chat_instance_invalid_provider():
119119
_create_chat_instance("invalid", "model")
120120

121121

122+
def test_create_chat_instance_azure_openai_missing_endpoint(monkeypatch):
123+
"""Azure OpenAI provider raises if AZURE_OPENAI_ENDPOINT is unset."""
124+
pytest.importorskip("openai")
125+
monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False)
126+
monkeypatch.setenv("OPENAI_API_VERSION", "2024-06-01")
127+
with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
128+
_create_chat_instance("azure-openai", "my-deployment")
129+
130+
131+
def test_create_chat_instance_azure_openai_missing_api_version(monkeypatch):
132+
"""Azure OpenAI provider raises if OPENAI_API_VERSION is unset."""
133+
pytest.importorskip("openai")
134+
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://example.openai.azure.com")
135+
monkeypatch.delenv("OPENAI_API_VERSION", raising=False)
136+
with pytest.raises(ValueError, match="OPENAI_API_VERSION"):
137+
_create_chat_instance("azure-openai", "my-deployment")
138+
139+
140+
def test_create_chat_instance_azure_openai_forwards_params(monkeypatch):
141+
"""Azure OpenAI provider forwards env vars + deployment id to ChatAzureOpenAI."""
142+
pytest.importorskip("openai")
143+
chatlas = pytest.importorskip("chatlas")
144+
monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://example.openai.azure.com")
145+
monkeypatch.setenv("OPENAI_API_VERSION", "2024-06-01")
146+
147+
sentinel = object()
148+
with patch.object(chatlas, "ChatAzureOpenAI", return_value=sentinel) as mock_cls:
149+
result = _create_chat_instance("azure-openai", "my-deployment", api_key="secret")
150+
151+
assert result is sentinel
152+
mock_cls.assert_called_once()
153+
kwargs = mock_cls.call_args.kwargs
154+
assert kwargs["endpoint"] == "https://example.openai.azure.com"
155+
assert kwargs["deployment_id"] == "my-deployment"
156+
assert kwargs["api_version"] == "2024-06-01"
157+
assert kwargs["api_key"] == "secret"
158+
assert "system_prompt" in kwargs
159+
assert "http_client" in kwargs["kwargs"]
160+
161+
122162
# ============================================================================
123163
# Test BatchConfig
124164
# ============================================================================

tests/test_prompt_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def test_prompt_inactive():
410410
("openai", "gpt-4o-mini"),
411411
("ollama", "llama2"),
412412
("bedrock", "anthropic.claude-3-sonnet-20240229-v1:0"),
413+
("azure-openai", "my-gpt4-deployment"),
413414
],
414415
)
415416
def test_prompt_with_different_providers(provider, model):

user_guide/01-validation-plan/02-validation-methods.qmd

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,13 @@ The `columns_subset=` parameter lets you specify which columns to include in the
611611
improving performance and reducing API costs by only sending relevant data to the LLM.
612612

613613
**Note:** To use [`Validate.prompt()`](`Validate.prompt`), you need to have the appropriate API credentials configured
614-
for your chosen LLM provider (Anthropic, OpenAI, Ollama, or AWS Bedrock).
614+
for your chosen LLM provider (Anthropic, OpenAI, Ollama, AWS Bedrock, or Azure OpenAI).
615+
616+
For **Azure OpenAI**, use `model="azure-openai:<deployment_id>"` where `<deployment_id>` is the
617+
name you assigned when deploying the model in your Azure OpenAI resource. In addition to
618+
`AZURE_OPENAI_API_KEY`, you must set `AZURE_OPENAI_ENDPOINT` (e.g.,
619+
`https://<resource>.openai.azure.com`) and `OPENAI_API_VERSION` (e.g., `2024-06-01`) in your
620+
environment.
615621

616622
## 5. Aggregate Validations
617623

0 commit comments

Comments
 (0)