Skip to content

Commit 0148864

Browse files
authored
[ENH] add semantic search infrastructure (#1122)
* add semantic search infrastructure * test the get_embeddings function * remove duplicate cassette * switch VectorType to migration_types * activate the vector extension * try to initialize extension * add new configs * fix pipeline_config_id * style * change env and update openapi
1 parent ba74bc9 commit 0148864

26 files changed

+2401
-1196
lines changed

.github/workflows/workflow.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,10 @@ jobs:
467467
docker compose exec -T \
468468
store-pgsql17 \
469469
psql -U postgres -c "create database test_db"
470+
471+
docker compose exec -T \
472+
store-pgsql17 \
473+
psql -U postgres -d test_db -c "CREATE EXTENSION IF NOT EXISTS vector;"
470474
-
471475
name: Initialize Compose Database
472476
run: |

store/.devcontainer/devcontainer.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
"ms-toolsai.jupyter",
3333
"ms-python.vscode-pylance",
3434
"ms-python.python",
35+
"RooVeterinaryInc.roo-cline",
36+
3537
]
3638

3739
// Use 'forwardPorts' to make a list of ports inside the container available locally.

store/.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ BEARERINFO_FUNC=neurostore.resources.auth.decode_token
1313
AUTH0_CLIENT_ID=YOUR_CLIENT_ID
1414
AUTH0_CLIENT_SECRET=YOUR_CLIENT_SECRET
1515
COMPOSE_AUTH0_CLIENT_ID=COMPOSE_CLIENT_ID
16+
OPENAI_API_KEY=YOUR_OPENAI_API_KEY
1617
V_HOST=localhost
1718
DEBUG=True

store/backend/cassettes/ingest_neurovault.yml

Lines changed: 0 additions & 1136 deletions
This file was deleted.
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import os
2+
from typing import List, Any, Optional
3+
4+
import openai
5+
from tenacity import (
6+
retry,
7+
stop_after_attempt,
8+
wait_exponential,
9+
retry_if_exception_type,
10+
)
11+
12+
13+
# Retry up to 3 attempts with small backoff: 0.5s then up to 1.0s
@retry(
    reraise=True,
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.5, max=1.0),
    retry=retry_if_exception_type(Exception),
)
def _call_openai_create(
    model: str,
    input_text: str,
    dimensions: Optional[int] = None,
    api_key: Optional[str] = None,
) -> Any:
    """
    Perform the OpenAI embeddings request, wrapped with tenacity retries.

    Prefers the modern SDK client (``openai.OpenAI().embeddings.create``)
    when available, and falls back to legacy surfaces
    (``openai.Embedding.create`` or ``openai.embeddings.create``).

    Args:
        model: Embedding model name passed to the API.
        input_text: Text to embed.
        dimensions: Optional embedding dimension; forwarded to the API only
            when provided (omitting it lets the API use the model default).
        api_key: Optional key used to construct the modern client. Legacy
            surfaces rely on the module-level ``openai.api_key`` instead.

    Returns:
        The raw SDK response (its shape depends on the installed SDK version).

    Raises:
        AttributeError: If no supported embeddings API is found on the
            ``openai`` package.
    """
    # Build the request kwargs once so every SDK surface receives identical
    # arguments and `dimensions` is only sent when explicitly requested.
    params: dict = {"model": model, "input": input_text}
    if dimensions is not None:
        params["dimensions"] = dimensions

    # Try modern OpenAI client (openai>=1.x) first.
    try:
        if hasattr(openai, "OpenAI"):
            client = openai.OpenAI(api_key=api_key) if api_key else openai.OpenAI()
            return client.embeddings.create(**params)
    except Exception:
        # NOTE(review): this deliberately swallows *all* modern-client errors
        # (including auth failures) and falls through to the legacy surfaces,
        # preserving the original best-effort fallback behavior.
        pass

    # Fallback: older SDK surfaces.
    if hasattr(openai, "Embedding") and hasattr(openai.Embedding, "create"):
        return openai.Embedding.create(**params)

    if hasattr(openai, "embeddings") and hasattr(openai.embeddings, "create"):
        return openai.embeddings.create(**params)

    raise AttributeError("No supported OpenAI embeddings API found on openai package")
63+
64+
65+
def get_embedding(text: str, dimensions: Optional[int] = None) -> List[float]:
    """
    Obtain an embedding vector for the provided text using the OpenAI embeddings API.

    Behavior:
    - Reads OPENAI_API_KEY from the environment (raises RuntimeError if missing).
    - Requests embeddings from the "text-embedding-3-small" model. If
      `dimensions` is provided it is passed through to the OpenAI API.
    - Retries up to 3 attempts with small backoff via tenacity on transient/network errors.
    - Validates the response shape and returns a plain list[float].

    Args:
        text: Input text to embed.
        dimensions: Optional requested embedding dimension (passed to OpenAI
            as `dimensions`).

    Returns:
        List[float]: Embedding vector.

    Raises:
        RuntimeError: If OPENAI_API_KEY is not set, the request fails after
            retries, or the response is malformed.
        ValueError: If `dimensions` is provided but not an int.
    """
    # Validate caller input before reading the environment or mutating any
    # module-global state.
    if dimensions is not None and not isinstance(dimensions, int):
        raise ValueError("dimensions must be an int when provided")

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY not set")

    # Configure the module-level key for legacy SDK surfaces; the modern
    # client receives the key explicitly via _call_openai_create.
    openai.api_key = api_key

    model_name = "text-embedding-3-small"

    try:
        # Use unified caller that handles both modern and legacy SDKs.
        resp = _call_openai_create(
            model=model_name, input_text=text, dimensions=dimensions, api_key=api_key
        )

        # Handle both legacy dict responses and modern OpenAI response objects.
        if isinstance(resp, dict):
            data = resp.get("data")
        elif hasattr(resp, "data"):
            data = resp.data
        else:
            raise RuntimeError(f"Unexpected response type from OpenAI: {type(resp)}")

        if not isinstance(data, list) or not data:
            raise RuntimeError(f"Invalid response shape from OpenAI: {resp}")

        first = data[0]
        # Extract the embedding whether `first` is a dict or an object with
        # an `.embedding` attribute.
        if isinstance(first, dict):
            embedding = first.get("embedding")
        elif hasattr(first, "embedding"):
            embedding = first.embedding
        else:
            raise RuntimeError(f"Invalid embedding in OpenAI response: {resp}")

        if not hasattr(embedding, "__iter__"):
            raise RuntimeError("Embedding returned by OpenAI is not iterable")

        # Normalize to a plain list of floats.
        return [float(x) for x in embedding]

    except Exception as exc:
        # Surface a clear runtime error for callers/tests.
        raise RuntimeError(f"Failed to get embedding: {exc}") from exc

0 commit comments

Comments
 (0)