Skip to content

feat(embedderconfig): add embedding_openai_endpoint argument to EmbedderConfig for custom OpenAI deployments #406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* **Migrate Jira Source connector from V1 to V2**
* **Add Jira Source connector integration and unit tests**
* **Added option for custom OpenAI baseurl in EmbedderConfig**

## 0.5.9

Expand Down
7 changes: 4 additions & 3 deletions unstructured_ingest/embed/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from pydantic import Field, SecretStr

Expand All @@ -26,6 +26,7 @@
class OpenAIEmbeddingConfig(EmbeddingConfig):
api_key: SecretStr
embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
base_url: Optional[str] = Field(default=None)

def wrap_error(self, e: Exception) -> Exception:
if is_internal_error(e=e):
Expand Down Expand Up @@ -57,13 +58,13 @@ def wrap_error(self, e: Exception) -> Exception:
def get_client(self) -> "OpenAI":
from openai import OpenAI

return OpenAI(api_key=self.api_key.get_secret_value())
return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)

@requires_dependencies(["openai"], extras="openai")
def get_async_client(self) -> "AsyncOpenAI":
from openai import AsyncOpenAI

return AsyncOpenAI(api_key=self.api_key.get_secret_value())
return AsyncOpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)


@dataclass
Expand Down
16 changes: 15 additions & 1 deletion unstructured_ingest/v2/processes/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class EmbedderConfig(BaseModel):
embedding_azure_api_version: Optional[str] = Field(
description="Azure API version", default=None
)
embedding_openai_endpoint: Optional[str] = Field(
default=None,
description="Your custom OpenAI base url, "
"e.g. `https://custom-openai-deployment.com/`",
)

def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
from unstructured_ingest.embed.huggingface import (
Expand All @@ -66,7 +71,16 @@ def get_huggingface_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEnco
def get_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
from unstructured_ingest.embed.openai import OpenAIEmbeddingConfig, OpenAIEmbeddingEncoder

return OpenAIEmbeddingEncoder(config=OpenAIEmbeddingConfig.model_validate(embedding_kwargs))
config_kwargs = {
"api_key": self.embedding_api_key,
"base_url": self.embedding_openai_endpoint,
}
if model_name := self.embedding_model_name:
config_kwargs["model_name"] = model_name

return OpenAIEmbeddingEncoder(
config=OpenAIEmbeddingConfig.model_validate(config_kwargs)
)

def get_azure_openai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
from unstructured_ingest.embed.azure_openai import (
Expand Down