Skip to content

Extend embeddings provider #65

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class EmbeddingsProvider(Enum):
HUGGINGFACE = "huggingface"
HUGGINGFACETEI = "huggingfacetei"
OLLAMA = "ollama"
GOOGLE = "google"
VOYAGE = "voyage"
SHUTTLEAI = "shuttleai"
COHERE = "cohere"


def get_env_variable(
Expand Down Expand Up @@ -171,7 +175,10 @@ async def dispatch(self, request, call_next):
).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")

GOOGLE_API_KEY = get_env_variable("GOOGLE_KEY", "")
VOYAGE_API_KEY = get_env_variable("VOYAGE_API_KEY", "")
SHUTTLEAI_KEY = get_env_variable("SHUTTLEAI_KEY", "") # API key for ShuttleAI, which is called through the OpenAI-compatible embeddings client
COHERE_API_KEY = get_env_variable("COHERE_API_KEY", "")
## Embeddings


Expand All @@ -198,6 +205,31 @@ def init_embeddings(provider, model):
return HuggingFaceHubEmbeddings(model=model)
elif provider == EmbeddingsProvider.OLLAMA:
return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
elif provider == EmbeddingsProvider.GOOGLE:
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings

return GoogleGenerativeAIEmbeddings(
model=model,
api_key=GOOGLE_API_KEY,
)
elif provider == EmbeddingsProvider.VOYAGE:
from langchain_voyageai import VoyageAIEmbeddings

return VoyageAIEmbeddings(
model=model,
)
elif provider == EmbeddingsProvider.SHUTTLEAI:
return OpenAIEmbeddings(
model=model,
api_key=SHUTTLEAI_KEY,
openai_api_base="https://api.shuttleai.app/v1",
)
elif provider == EmbeddingsProvider.COHERE:
from langchain_cohere import CohereEmbeddings

return CohereEmbeddings(
model=model,
)
else:
raise ValueError(f"Unsupported embeddings provider: {provider}")

Expand All @@ -220,6 +252,15 @@ def init_embeddings(provider, model):
)
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "models/embedding-001")
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.VOYAGE:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "voyage-large-2")
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.SHUTTLEAI:
# Model name options: text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-large")
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.COHERE:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "embed-multilingual-v3.0")
else:
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")

Expand Down
8 changes: 8 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: rag_api
channels:
- defaults
dependencies:
- python=3.11
- pip
- pip:
- -r requirements.lite.txt
4 changes: 4 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ async def embed_file(
)

try:
logger.info(f"Received file for embedding: filename={file.filename}, content_type={file.content_type}, file_id={file_id}")
loader, known_type, file_ext = get_loader(
file.filename, file.content_type, temp_file_path
)
Expand All @@ -403,13 +404,15 @@ async def embed_file(
if not result:
response_status = False
response_message = "Failed to process/store the file data."
logger.error(response_message, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process/store the file data.",
)
elif "error" in result:
response_status = False
response_message = "Failed to process/store the file data."
logger.error(response_message, exc_info=True)
if isinstance(result["error"], str):
response_message = result["error"]
else:
Expand All @@ -420,6 +423,7 @@ async def embed_file(
except Exception as e:
response_status = False
response_message = f"Error during file processing: {str(e)}"
logger.error(response_message, exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}",
Expand Down
66 changes: 41 additions & 25 deletions requirements.lite.txt
Original file line number Diff line number Diff line change
@@ -1,29 +1,45 @@
langchain==0.1.12
langchain_community==0.0.34
langchain_openai==0.0.8
langchain_core==0.1.45
sqlalchemy==2.0.28
python-dotenv==1.0.1
fastapi==0.110.0
# LangChain
langchain==0.2.12
langchain_community==0.2.11
langchain_openai==0.1.21
langchain_core==0.2.29
langchain-mongodb==0.1.8
langchain-voyageai==0.1.1
langchain-google-genai==1.0.8
langchain-cohere==0.2.1

# API
fastapi==0.112.0
uvicorn==0.30.5
python-multipart==0.0.9
aiofiles==24.1.0

# Database
sqlalchemy==2.0.32
psycopg2-binary==2.9.9
pgvector==0.2.5
uvicorn==0.28.0
pypdf==4.1.0
unstructured==0.12.6
markdown==3.6
networkx==3.2.1
pandas==2.2.1
openpyxl==3.1.2
pgvector==0.3.2
asyncpg==0.29.0
pymongo==4.8.0

# Data Processing and Analysis
pandas==2.2.2
openpyxl==3.1.5
networkx==3.3

# File Handling and Parsing
pypdf==4.3.1
unstructured==0.15.1
docx2txt==0.8
pypandoc==1.13
PyJWT==2.8.0
asyncpg==0.29.0
python-multipart==0.0.9
aiofiles==23.2.1
rapidocr-onnxruntime==1.3.17
opencv-python-headless==4.9.0.80
pymongo==4.6.3
langchain-mongodb==0.1.3
cryptography==42.0.7
python-magic==0.4.27
python-pptx==0.6.23
python-pptx==1.0.2

# Security
PyJWT==2.9.0
cryptography==43.0.0

# Miscellaneous
python-dotenv==1.0.1
markdown==3.6
rapidocr-onnxruntime==1.3.24
opencv-python-headless==4.10.0.84
68 changes: 42 additions & 26 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,30 +1,46 @@
langchain==0.1.12
langchain_community==0.0.34
langchain_openai==0.0.8
langchain_core==0.1.45
sqlalchemy==2.0.28
python-dotenv==1.0.1
fastapi==0.110.0
# LangChain
langchain==0.2.12
langchain_community==0.2.11
langchain_openai==0.1.21
langchain_core==0.2.29
langchain-mongodb==0.1.8
langchain-voyageai==0.1.1
langchain-google-genai==1.0.8
langchain-cohere==0.2.1

# API
fastapi==0.112.0
uvicorn==0.30.5
python-multipart==0.0.9
aiofiles==24.1.0

# Database
sqlalchemy==2.0.32
psycopg2-binary==2.9.9
pgvector==0.2.5
uvicorn==0.28.0
pypdf==4.1.0
unstructured==0.12.6
markdown==3.6
networkx==3.2.1
pandas==2.2.1
openpyxl==3.1.2
pgvector==0.3.2
asyncpg==0.29.0
pymongo==4.8.0

# Data Processing and Analysis
pandas==2.2.2
openpyxl==3.1.5
networkx==3.3

# File Handling and Parsing
pypdf==4.3.1
unstructured==0.15.1
docx2txt==0.8
pypandoc==1.13
PyJWT==2.8.0
asyncpg==0.29.0
python-multipart==0.0.9
sentence_transformers==2.5.1
aiofiles==23.2.1
rapidocr-onnxruntime==1.3.17
opencv-python-headless==4.9.0.80
pymongo==4.6.3
langchain-mongodb==0.1.3
cryptography==42.0.7
python-magic==0.4.27
python-pptx==0.6.23
python-pptx==1.0.2

# Security
PyJWT==2.9.0
cryptography==43.0.0

# Miscellaneous
python-dotenv==1.0.1
markdown==3.6
rapidocr-onnxruntime==1.3.24
opencv-python-headless==4.10.0.84
sentence_transformers==3.0.1