@@ -73,6 +73,7 @@ def __init__(
7373 stream : bool = True ,
7474 reasoning_effort : Literal ["low" , "medium" , "high" ] | None = None ,
7575 organization : str = None ,
76+ client_type : Literal ["openai" , "azure" ] = "openai" ,
7677 client_kwargs : dict [str , JSONSerializableObject ] | None = None ,
7778 generate_kwargs : dict [str , JSONSerializableObject ] | None = None ,
7879 ** kwargs : Any ,
@@ -96,6 +97,8 @@ def __init__(
9697 organization (`str`, default `None`):
9798 The organization ID for OpenAI API. If not specified, it will
9899 be read from the environment variable `OPENAI_ORGANIZATION`.
100+ client_type (`Literal["openai", "azure"]`, default `openai`):
101+ Selects which OpenAI-compatible client to initialize.
99102 client_kwargs (`dict[str, JSONSerializableObject] | None`, \
100103 optional):
101104 The extra keyword arguments to initialize the OpenAI client.
@@ -134,11 +137,23 @@ def __init__(
134137
135138 import openai
136139
137- self .client = openai .AsyncClient (
138- api_key = api_key ,
139- organization = organization ,
140- ** (client_kwargs or {}),
141- )
140+ if client_type not in ("openai" , "azure" ):
141+ raise ValueError (
142+ "Invalid client_type. Supported values: 'openai', 'azure'." ,
143+ )
144+
145+ if client_type == "azure" :
146+ self .client = openai .AsyncAzureOpenAI (
147+ api_key = api_key ,
148+ organization = organization ,
149+ ** (client_kwargs or {}),
150+ )
151+ else :
152+ self .client = openai .AsyncClient (
153+ api_key = api_key ,
154+ organization = organization ,
155+ ** (client_kwargs or {}),
156+ )
142157
143158 self .reasoning_effort = reasoning_effort
144159 self .generate_kwargs = generate_kwargs or {}
0 commit comments