|
10 | 10 |
|
11 | 11 | from fastapi import Body, File, Form, HTTPException, UploadFile |
12 | 12 | from langchain.text_splitter import RecursiveCharacterTextSplitter |
13 | | -from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceHubEmbeddings, OpenAIEmbeddings |
| 13 | +from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceInferenceAPIEmbeddings, OpenAIEmbeddings |
14 | 14 | from langchain_core.documents import Document |
15 | 15 | from langchain_milvus.vectorstores import Milvus |
16 | 16 | from langchain_text_splitters import HTMLHeaderTextSplitter |
|
36 | 36 | # Local Embedding model |
37 | 37 | LOCAL_EMBEDDING_MODEL = os.getenv("LOCAL_EMBEDDING_MODEL", "maidalun1020/bce-embedding-base_v1") |
38 | 38 | # TEI configuration |
39 | | -TEI_EMBEDDING_MODEL = os.environ.get("TEI_EMBEDDING_MODEL", "/home/user/bge-large-zh-v1.5") |
| 39 | +EMBED_MODEL = os.environ.get("EMBED_MODEL", "BAAI/bge-base-en-v1.5") |
40 | 40 | TEI_EMBEDDING_ENDPOINT = os.environ.get("TEI_EMBEDDING_ENDPOINT", "") |
| 41 | +# Hugging Face API token; required whenever TEI_EMBEDDING_ENDPOINT is used |
| 42 | +HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "") |
| 43 | + |
41 | 44 | # MILVUS configuration |
42 | 45 | MILVUS_HOST = os.getenv("MILVUS_HOST", "localhost") |
43 | 46 | MILVUS_PORT = int(os.getenv("MILVUS_PORT", 19530)) |
@@ -75,7 +78,7 @@ def ingest_chunks_to_milvus(embeddings, file_name: str, chunks: List): |
75 | 78 | except Exception as e: |
76 | 79 | if logflag: |
77 | 80 | logger.info(f"[ ingest chunks ] fail to ingest chunks into Milvus. error: {e}") |
78 | | - raise HTTPException(status_code=500, detail=f"Fail to store chunks of file {file_name}.") |
| 81 | + raise HTTPException(status_code=500, detail=f"Fail to store chunks of file {file_name}: {e}") |
79 | 82 |
|
80 | 83 | if logflag: |
81 | 84 | logger.info(f"[ ingest chunks ] Docs ingested file {file_name} to Milvus collection {COLLECTION_NAME}.") |
@@ -189,7 +192,23 @@ def _initialize_embedder(self): |
189 | 192 | # create embeddings using TEI endpoint service |
190 | 193 | if logflag: |
191 | 194 | logger.info(f"[ milvus embedding ] TEI_EMBEDDING_ENDPOINT:{TEI_EMBEDDING_ENDPOINT}") |
192 | | - embeddings = HuggingFaceHubEmbeddings(model=TEI_EMBEDDING_ENDPOINT) |
| 195 | + if not HUGGINGFACEHUB_API_TOKEN: |
| 196 | + raise HTTPException( |
| 197 | + status_code=400, |
| 198 | + detail="You MUST offer the `HUGGINGFACEHUB_API_TOKEN` when using `TEI_EMBEDDING_ENDPOINT`.", |
| 199 | + ) |
| 200 | + import requests |
| 201 | + |
| 202 | + response = requests.get(TEI_EMBEDDING_ENDPOINT + "/info") |
| 203 | + if response.status_code != 200: |
| 204 | + raise HTTPException( |
| 205 | + status_code=400, detail=f"TEI embedding endpoint {TEI_EMBEDDING_ENDPOINT} is not available." |
| 206 | + ) |
| 207 | + model_id = response.json()["model_id"] |
| 208 | + # build the embedder against the TEI endpoint, using the model id it reported |
| 209 | + embeddings = HuggingFaceInferenceAPIEmbeddings( |
| 210 | + api_key=HUGGINGFACEHUB_API_TOKEN, model_name=model_id, api_url=TEI_EMBEDDING_ENDPOINT |
| 211 | + ) |
193 | 212 | else: |
194 | 213 | # create embeddings using local embedding model |
195 | 214 | if logflag: |
@@ -274,7 +293,7 @@ async def ingest_files( |
274 | 293 | search_res = search_by_file(my_milvus.col, encode_file) |
275 | 294 | except Exception as e: |
276 | 295 | raise HTTPException( |
277 | | - status_code=500, detail=f"Failed when searching in Milvus db for file {file.filename}." |
| 296 | + status_code=500, detail=f"Failed when searching in Milvus db for file {file.filename}: {e}" |
278 | 297 | ) |
279 | 298 | if len(search_res) > 0: |
280 | 299 | if logflag: |
@@ -319,7 +338,7 @@ async def ingest_files( |
319 | 338 | search_res = search_by_file(my_milvus.col, encoded_link + ".txt") |
320 | 339 | except Exception as e: |
321 | 340 | raise HTTPException( |
322 | | - status_code=500, detail=f"Failed when searching in Milvus db for link {link}." |
| 341 | + status_code=500, detail=f"Failed when searching in Milvus db for link {link}: {e}" |
323 | 342 | ) |
324 | 343 | if len(search_res) > 0: |
325 | 344 | if logflag: |
@@ -375,7 +394,7 @@ async def get_files(self): |
375 | 394 | try: |
376 | 395 | all_data = search_all(my_milvus.col) |
377 | 396 | except Exception as e: |
378 | | - raise HTTPException(status_code=500, detail="Failed when searching in Milvus db for all files.") |
| 397 | + raise HTTPException(status_code=500, detail=f"Failed when searching in Milvus db for all files: {e}") |
379 | 398 |
|
380 | 399 | # return [] if no data in db |
381 | 400 | if len(all_data) == 0: |
@@ -422,8 +441,7 @@ async def delete_files(self, file_path: str = Body(..., embed=True)): |
422 | 441 | except Exception as e: |
423 | 442 | if logflag: |
424 | 443 | logger.info(f"[ milvus delete ] {e}. Fail to delete {upload_folder}.") |
425 | | - raise HTTPException(status_code=500, detail=f"Fail to delete {upload_folder}.") |
426 | | - |
| 444 | + raise HTTPException(status_code=500, detail=f"Fail to delete {upload_folder}: {e}") |
427 | 445 | if logflag: |
428 | 446 | logger.info("[ milvus delete ] successfully delete all files.") |
429 | 447 |
|
|
0 commit comments