Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions pointblank/draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ class DraftValidation:
starting point for validating a table. This can be useful when you have a new table and you
want to get a sense of how to validate it (and adjustments could always be made later). The
`DraftValidation` class uses the `chatlas` package to draft a validation plan for a given table
using an LLM from either the `"anthropic"`, `"openai"`, `"ollama"` or `"bedrock"` provider. You
can install all requirements for the class through an optional 'generate' install of Pointblank
via `pip install pointblank[generate]`.
using an LLM from the `"anthropic"`, `"openai"`, `"ollama"`, `"bedrock"`, or `"azure-openai"`
provider. You can install all requirements for the class through an optional 'generate' install
of Pointblank via `pip install pointblank[generate]`.

:::{.callout-warning}
The `DraftValidation` class is still experimental. Please report any issues you encounter in
Expand All @@ -38,7 +38,8 @@ class DraftValidation:
model
The model to be used. This should be in the form of `provider:model` (e.g.,
`"anthropic:claude-opus-4-6"`). Supported providers are `"anthropic"`, `"openai"`,
`"ollama"`, and `"bedrock"`.
`"ollama"`, `"bedrock"`, and `"azure-openai"`. For `"azure-openai"`, the value after the
colon is the Azure *deployment id*, not an OpenAI model id.
api_key
The API key to be used for the model.
verify_ssl
Expand All @@ -61,9 +62,15 @@ class DraftValidation:
- `"openai"` (OpenAI)
- `"ollama"` (Ollama)
- `"bedrock"` (Amazon Bedrock)
- `"azure-openai"` (Azure OpenAI)

The model name should be the specific model to be used from the provider. Model names are
subject to change so consult the provider's documentation for the most up-to-date model names.
For `"azure-openai"`, the value after the colon is the Azure *deployment id* (the name you
assigned when deploying the model in your Azure OpenAI resource). The `"azure-openai"`
provider also requires the environment variables `AZURE_OPENAI_ENDPOINT` (e.g.,
`https://<resource>.openai.azure.com`) and `OPENAI_API_VERSION` (e.g., `"2024-06-01"`) to be
set. The API key can be passed through `api_key=` or provided in the `AZURE_OPENAI_API_KEY`
environment variable.

Notes on Authentication
-----------------------
Expand Down Expand Up @@ -251,7 +258,9 @@ def __post_init__(self) -> None:
)

# Read the API/examples text from a file
with files("pointblank.data").joinpath("api-docs.txt").open() as f: # pragma: no cover
with (
files("pointblank.data").joinpath("api-docs.txt").open(encoding="utf-8") as f
): # pragma: no cover
api_and_examples_text = f.read()

# Get the model name from the `model` value
Expand Down Expand Up @@ -389,6 +398,41 @@ def __post_init__(self) -> None:
kwargs={"http_client": http_client},
)

if provider == "azure-openai": # pragma: no cover
try:
import openai # noqa
except ImportError: # pragma: no cover
raise ImportError( # pragma: no cover
"The `openai` package is required to use the `DraftValidation` class with "
"the `azure-openai` provider. Please install it using `pip install openai`."
)

import os

endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_version = os.getenv("OPENAI_API_VERSION")
if not endpoint:
raise ValueError(
"AZURE_OPENAI_ENDPOINT environment variable must be set to use "
"the 'azure-openai' provider."
)
if not api_version:
raise ValueError(
"OPENAI_API_VERSION environment variable must be set to use "
"the 'azure-openai' provider (e.g. '2024-06-01')."
)

from chatlas import ChatAzureOpenAI # pragma: no cover

chat = ChatAzureOpenAI( # pragma: no cover
endpoint=endpoint,
deployment_id=model_name,
api_version=api_version,
api_key=self.api_key,
system_prompt="You are a terse assistant and a Python expert.",
kwargs={"http_client": http_client},
)

self.response = str(chat.chat(prompt, stream=False, echo="none")) # pragma: no cover

def __str__(self) -> str:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_draft.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,23 @@ def test_draft_fail_invalid_provider():

with pytest.raises(ValueError):
DraftValidation(data=small_table, model="invalid:model")


def test_draft_fail_azure_openai_missing_endpoint(monkeypatch):
    """Drafting with `azure-openai` raises `ValueError` when the endpoint variable is unset."""
    # Skip (rather than fail) when the optional LLM dependencies are not installed
    for optional_pkg in ("openai", "chatlas"):
        pytest.importorskip(optional_pkg)

    # Provide the API version but make sure the endpoint is absent from the environment
    monkeypatch.setenv("OPENAI_API_VERSION", "2024-06-01")
    monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False)

    tbl = load_dataset(dataset="small_table")

    # The error message is expected to name the missing variable
    with pytest.raises(ValueError, match="AZURE_OPENAI_ENDPOINT"):
        DraftValidation(data=tbl, model="azure-openai:my-deployment")


def test_draft_fail_azure_openai_missing_api_version(monkeypatch):
    """Drafting with `azure-openai` raises `ValueError` when the API version variable is unset."""
    # Skip (rather than fail) when the optional LLM dependencies are not installed
    for optional_pkg in ("openai", "chatlas"):
        pytest.importorskip(optional_pkg)

    # Provide the endpoint but make sure the API version is absent from the environment
    monkeypatch.delenv("OPENAI_API_VERSION", raising=False)
    monkeypatch.setenv("AZURE_OPENAI_ENDPOINT", "https://example.openai.azure.com")

    tbl = load_dataset(dataset="small_table")

    # The error message is expected to name the missing variable
    with pytest.raises(ValueError, match="OPENAI_API_VERSION"):
        DraftValidation(data=tbl, model="azure-openai:my-deployment")
10 changes: 10 additions & 0 deletions user_guide/02-advanced-validation/04-draft-validation.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ The `DraftValidation` class supports multiple LLM providers:
- **OpenAI** (GPT models)
- **Ollama** (local LLMs)
- **Amazon Bedrock** (AWS-hosted models)
- **Azure OpenAI** (OpenAI models deployed on Azure)

Each provider has different capabilities and performance characteristics, but all can be used to
generate validation plans through a consistent interface.
Expand Down Expand Up @@ -171,6 +172,11 @@ You can also store API keys in a `.env` file in your project's root directory:
# Contents of .env file
ANTHROPIC_API_KEY=your_anthropic_api_key_here
OPENAI_API_KEY=your_openai_api_key_here

# For Azure OpenAI, three variables are required:
AZURE_OPENAI_API_KEY=your_azure_openai_api_key_here
AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com
OPENAI_API_VERSION=2025-03-01-preview
```

If your API keys have standard names (like `ANTHROPIC_API_KEY` or `OPENAI_API_KEY`),
Expand Down Expand Up @@ -280,6 +286,10 @@ pb.DraftValidation(data=data, model="ollama:llama3:latest")

# Using Amazon Bedrock
pb.DraftValidation(data=data, model="bedrock:anthropic.claude-3-sonnet-20240229-v1:0")

# Using Azure OpenAI (the value after the colon is the Azure deployment id, not an OpenAI model id;
# requires AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and OPENAI_API_VERSION env vars)
pb.DraftValidation(data=data, model="azure-openai:my-gpt4-deployment")
```

### Model Performance and Privacy
Expand Down
Loading