Skip to content

Commit 8354850

Browse files
authored
fix: replace eval() with json.loads(), add embedding retry, use dynamic version (#23)
Signed-off-by: Cheney Zhang <chen.zhang@zilliz.com>
1 parent 8c365e1 commit 8354850

3 files changed

Lines changed: 25 additions & 9 deletions

File tree

src/vector_graph_rag/api/app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
 from fastapi.responses import FileResponse
 from pydantic import BaseModel, Field
 
-from vector_graph_rag import VectorGraphRAG
+from vector_graph_rag import VectorGraphRAG, __version__
 from vector_graph_rag.config import Settings, get_settings
 from vector_graph_rag.storage.milvus import MilvusStore
 from vector_graph_rag.graph.graph import Graph
@@ -32,7 +32,7 @@
 class HealthResponse(BaseModel):
     """Health check response."""
     status: str = Field(default="ok", description="Service status")
-    version: str = Field(default="0.1.0", description="API version")
+    version: str = Field(default=__version__, description="API version")
 
 
 class GraphInfo(BaseModel):
@@ -258,7 +258,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
     app = FastAPI(
         title="Vector Graph RAG API",
         description="Graph RAG using pure vector search with Milvus",
-        version="0.1.0",
+        version=__version__,
     )
 
     # Add CORS middleware for frontend
@@ -300,7 +300,7 @@ def get_graph(graph_name: Optional[str] = None) -> Graph:
     @app.get("/health", response_model=HealthResponse, tags=["System"])
     async def health_check():
         """Check if the service is running."""
-        return HealthResponse(status="ok", version="0.1.0")
+        return HealthResponse(status="ok", version=__version__)
 
     @app.get("/graphs", response_model=ListGraphsResponse, tags=["System"])
     async def list_graphs():

src/vector_graph_rag/llm/extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,10 @@ def _load_tsv_cache(self, cache_file: str) -> None:
                 query = row.get(query_col, '')
                 triples_str = row.get('triples', '{}')
                 try:
-                    triples_data = eval(triples_str) if isinstance(triples_str, str) else triples_str
+                    triples_data = json.loads(triples_str) if isinstance(triples_str, str) else triples_str
                     if isinstance(triples_data, dict) and 'named_entities' in triples_data:
                         self.ner_tsv_cache[query] = triples_data['named_entities']
-                except:
+                except (json.JSONDecodeError, KeyError, TypeError):
                     pass
             print(f"Loaded {len(self.ner_tsv_cache)} NER entries from {cache_file}")
         except Exception as e:

src/vector_graph_rag/storage/embeddings.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,26 @@ def __init__(self, model_name: str, api_key: str, base_url: Optional[str] = None
 
         self.model_name = model_name
         self.client = OpenAI(api_key=api_key, base_url=base_url)
-        self._retry_decorator = retry(
+
+        # Wrap the API call with retry logic
+        @retry(
             stop=stop_after_attempt(3),
             wait=wait_exponential(multiplier=1, min=2, max=10),
         )
+        def _call_api(texts):
+            return self.client.embeddings.create(model=self.model_name, input=texts)
+
+        self._call_api = _call_api
+
+        # Detect embedding dimension lazily
+        self._dimension: Optional[int] = None
+
+    def _get_dimension(self) -> int:
+        """Get embedding dimension by making a test call."""
+        if self._dimension is None:
+            response = self._call_api(["test"])
+            self._dimension = len(response.data[0].embedding)
+        return self._dimension
 
     def encode(
         self,
@@ -182,9 +198,9 @@ def encode(
         valid_texts = [texts[i] for i in valid_indices]
 
         if not valid_texts:
-            return np.zeros((len(texts), 1536))
+            return np.zeros((len(texts), self._get_dimension()))
 
-        response = self.client.embeddings.create(model=self.model_name, input=valid_texts)
+        response = self._call_api(valid_texts)
         sorted_data = sorted(response.data, key=lambda x: x.index)
         valid_embeddings = np.array([item.embedding for item in sorted_data])
0 commit comments

Comments (0)