Skip to content

Commit 9ed5a30

Browse files
authored
feat(openai): add Azure OpenAI support in OpenAIChatModel class (#932)
1 parent ea8cb9d commit 9ed5a30

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

src/agentscope/model/_openai_model.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(
7373
stream: bool = True,
7474
reasoning_effort: Literal["low", "medium", "high"] | None = None,
7575
organization: str = None,
76+
client_type: Literal["openai", "azure"] = "openai",
7677
client_kwargs: dict[str, JSONSerializableObject] | None = None,
7778
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
7879
**kwargs: Any,
@@ -96,6 +97,8 @@ def __init__(
9697
organization (`str`, default `None`):
9798
The organization ID for OpenAI API. If not specified, it will
9899
be read from the environment variable `OPENAI_ORGANIZATION`.
100+
client_type (`Literal["openai", "azure"]`, default `openai`):
101+
Selects which OpenAI-compatible client to initialize.
99102
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
100103
optional):
101104
The extra keyword arguments to initialize the OpenAI client.
@@ -134,11 +137,23 @@ def __init__(
134137

135138
import openai
136139

137-
self.client = openai.AsyncClient(
138-
api_key=api_key,
139-
organization=organization,
140-
**(client_kwargs or {}),
141-
)
140+
if client_type not in ("openai", "azure"):
141+
raise ValueError(
142+
"Invalid client_type. Supported values: 'openai', 'azure'.",
143+
)
144+
145+
if client_type == "azure":
146+
self.client = openai.AsyncAzureOpenAI(
147+
api_key=api_key,
148+
organization=organization,
149+
**(client_kwargs or {}),
150+
)
151+
else:
152+
self.client = openai.AsyncClient(
153+
api_key=api_key,
154+
organization=organization,
155+
**(client_kwargs or {}),
156+
)
142157

143158
self.reasoning_effort = reasoning_effort
144159
self.generate_kwargs = generate_kwargs or {}

0 commit comments

Comments
 (0)