Skip to content

Commit 9fb48c1

Browse files
committed
added optional openai base_url argument
1 parent f4da290 commit 9fb48c1

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
* **Migrate Jira Source connector from V1 to V2**
66
* **Add Jira Source connector integration and unit tests**
7+
* **Added option for custom OpenAI baseurl in EmbedderConfig**
78

89
## 0.5.9
910

unstructured_ingest/embed/openai.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import TYPE_CHECKING
2+
from typing import TYPE_CHECKING, Optional
33

44
from pydantic import Field, SecretStr
55

@@ -26,6 +26,7 @@
2626
class OpenAIEmbeddingConfig(EmbeddingConfig):
2727
api_key: SecretStr
2828
embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
29+
base_url: Optional[str] = Field(default=None)
2930

3031
def wrap_error(self, e: Exception) -> Exception:
3132
if is_internal_error(e=e):
@@ -57,13 +58,13 @@ def wrap_error(self, e: Exception) -> Exception:
5758
def get_client(self) -> "OpenAI":
5859
from openai import OpenAI
5960

60-
return OpenAI(api_key=self.api_key.get_secret_value())
61+
return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
6162

6263
@requires_dependencies(["openai"], extras="openai")
6364
def get_async_client(self) -> "AsyncOpenAI":
6465
from openai import AsyncOpenAI
6566

66-
return AsyncOpenAI(api_key=self.api_key.get_secret_value())
67+
return AsyncOpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
6768

6869

6970
@dataclass

unstructured_ingest/v2/processes/embedder.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ class EmbedderConfig(BaseModel):
5252
embedding_azure_api_version: Optional[str] = Field(
5353
description="Azure API version", default=None
5454
)
55+
embedding_openai_endpoint: Optional[str] = Field(
56+
default=None,
57+
description="Your custom OpenAI base url, "
58+
"e.g. `https://custom-openai-deployment.com/`",
59+
)
5560

5661
def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
5762
from unstructured_ingest.embed.huggingface import (
@@ -66,7 +71,16 @@ def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEnco
6671
def get_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
6772
from unstructured_ingest.embed.openai import OpenAIEmbeddingConfig, OpenAIEmbeddingEncoder
6873

69-
return OpenAIEmbeddingEncoder(config=OpenAIEmbeddingConfig.model_validate(embedding_kwargs))
74+
config_kwargs = {
75+
"api_key": self.embedding_api_key,
76+
"base_url": self.embedding_openai_endpoint,
77+
}
78+
if model_name := self.embedding_model_name:
79+
config_kwargs["model_name"] = model_name
80+
81+
return OpenAIEmbeddingEncoder(
82+
config=OpenAIEmbeddingConfig.model_validate(config_kwargs)
83+
)
7084

7185
def get_azure_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
7286
from unstructured_ingest.embed.azure_openai import (

0 commit comments

Comments
 (0)