Skip to content

380 add organization and headers for open ai #384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions adalflow/adalflow/components/model_client/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,24 @@ def __init__(
input_type: Literal["text", "messages"] = "text",
base_url: str = "https://api.openai.com/v1/",
env_api_key_name: str = "OPENAI_API_KEY",
organization: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.

Args:
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
base_url (str): The API base URL to use when initializing the client.
env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`.
organization (Optional[str], optional): OpenAI organization key. Defaults to None.
headers (Optional[Dict[str, str]], optional): Additional headers to include in API requests. Defaults to None.
"""
super().__init__()
self._api_key = api_key
self.base_url = base_url
self._env_api_key_name = env_api_key_name
self.organization = organization
self.headers = headers or {}
self.sync_client = self.init_sync_client()
self.async_client = None # only initialize if the async call is called
self.chat_completion_parser = (
Expand All @@ -191,15 +197,25 @@ def init_sync_client(self):
raise ValueError(
f"Environment variable {self._env_api_key_name} must be set"
)
return OpenAI(api_key=api_key, base_url=self.base_url)
return OpenAI(
api_key=api_key,
base_url=self.base_url,
organization=self.organization,
default_headers=self.headers,
)

def init_async_client(self):
api_key = self._api_key or os.getenv(self._env_api_key_name)
if not api_key:
raise ValueError(
f"Environment variable {self._env_api_key_name} must be set"
)
return AsyncOpenAI(api_key=api_key, base_url=self.base_url)
return AsyncOpenAI(
api_key=api_key,
base_url=self.base_url,
organization=self.organization,
default_headers=self.headers,
)

# def _parse_chat_completion(self, completion: ChatCompletion) -> "GeneratorOutput":
# # TODO: raw output it is better to save the whole completion as a source of truth instead of just the message
Expand Down Expand Up @@ -588,4 +604,3 @@ def _prepare_image_content(
)
resopnse = openai_llm(prompt_kwargs={"input_str": "What is LLM?"})
print(resopnse)

111 changes: 111 additions & 0 deletions adalflow/tests/test_openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,117 @@ def test_from_dict_to_dict(self):
new_client = OpenAIClient.from_dict(client_dict)
self.assertEqual(new_client.to_dict(), client_dict)

@patch("adalflow.components.model_client.openai_client.OpenAI")
def test_init_sync_client_with_headers_and_organization(self, MockOpenAI):
headers = {"Custom-Header": "CustomValue"}
organization = "test-organization"

# First call happens during __init__
client = OpenAIClient(
api_key="fake_api_key",
headers=headers,
organization=organization,
)

# Clear previous calls so we only test the explicit one below
MockOpenAI.reset_mock()

# Now call init_sync_client explicitly to trigger the OpenAI call
_ = client.init_sync_client()

# Assert OpenAI was called with correct parameters
MockOpenAI.assert_called_once_with(
api_key="fake_api_key",
base_url="https://api.openai.com/v1/",
organization=organization,
default_headers=headers,
)

@patch("adalflow.components.model_client.openai_client.AsyncOpenAI")
async def test_init_async_client_with_headers_and_organization(
self, MockAsyncOpenAI
):
headers = {"Custom-Header": "CustomValue"}
organization = "test-organization"

# Manually assign an AsyncMock to the return value
mock_async_client = AsyncMock()
MockAsyncOpenAI.return_value = mock_async_client

client = OpenAIClient(
api_key="fake_api_key",
headers=headers,
organization=organization,
)

async_client = client.init_async_client() # Do NOT await here

MockAsyncOpenAI.assert_called_once_with(
api_key="fake_api_key",
base_url="https://api.openai.com/v1/",
organization=organization,
default_headers=headers,
)
self.assertEqual(async_client, mock_async_client)

@patch("adalflow.components.model_client.openai_client.OpenAI")
def test_call_with_custom_headers_and_organization(self, MockOpenAI):
# Test that headers and organization are passed during a call
headers = {"Custom-Header": "CustomValue"}
organization = "test-organization"
mock_sync_client = Mock()
MockOpenAI.return_value = mock_sync_client

client = OpenAIClient(
api_key="fake_api_key",
headers=headers,
organization=organization,
)
client.sync_client = mock_sync_client

# Mock the API call
mock_sync_client.chat.completions.create = Mock(return_value=self.mock_response)

# Call the method
result = client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)

# Assertions
mock_sync_client.chat.completions.create.assert_called_once_with(
**self.api_kwargs
)
self.assertEqual(result, self.mock_response)

@patch("adalflow.components.model_client.openai_client.AsyncOpenAI")
async def test_acall_with_custom_headers_and_organization(self, MockAsyncOpenAI):
# Test that headers and organization are passed during an async call
headers = {"Custom-Header": "CustomValue"}
organization = "test-organization"
mock_async_client = AsyncMock()
MockAsyncOpenAI.return_value = mock_async_client

client = OpenAIClient(
api_key="fake_api_key",
headers=headers,
organization=organization,
)
client.async_client = mock_async_client

# Mock the API call
mock_async_client.chat.completions.create = AsyncMock(
return_value=self.mock_response
)

# Call the method
result = await client.acall(
api_kwargs=self.api_kwargs, model_type=ModelType.LLM
)

# Assertions
mock_async_client.chat.completions.create.assert_awaited_once_with(
**self.api_kwargs
)
self.assertEqual(result, self.mock_response)


if __name__ == "__main__":
unittest.main()
Loading