Skip to content

Commit d01482f

Browse files
authored
FROM feat/530-semantic-search-over-threads TO development (#531)
* init * Appears to have one-shotted, few pieces to fix but got for most part * We have full semantic but the content returned is the docs and we need to clean this up. * Switch to name thread_repo * Consolidates thread_repo, had wrong init as well * Semantic search for threads * fix build * GPT 4.1 default model
1 parent ac8cc9c commit d01482f

22 files changed

Lines changed: 601 additions & 101 deletions

File tree

Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
- bug/370-anthropic-streaming (2025-09-16)
2323

2424
### Changed
25+
- feat/530-semantic-search-over-threads (2025-11-30)
2526
- feat/534-mermaid-diagram (2025-11-30)
2627
- feat/528-pagination-for-threads (2025-11-30)
2728
- feat/519-react-window (2025-11-23)

backend/src/constants/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,7 @@ def values(cls) -> list[str]:
8888

8989
# GridSite
9090
MICROSOFT_TEAMS_WEBHOOK_URL = os.getenv("MICROSOFT_TEAMS_WEBHOOK_URL")
91+
92+
# Thread Search
93+
# Number of recent messages to store per thread snapshot for semantic search
94+
THREAD_SNAPSHOT_MESSAGE_COUNT = 20

backend/src/constants/examples/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,26 @@ class Examples:
411411
},
412412
),
413413
}
414+
THREAD_SEMANTIC_SEARCH_EXAMPLES = {
415+
"semantic_search": Example(
416+
summary="semantic_search",
417+
description="Search threads using natural language",
418+
value={
419+
"query": "threads about database optimization",
420+
"limit": 10,
421+
"assistant_id": None,
422+
},
423+
),
424+
"semantic_search_with_assistant": Example(
425+
summary="semantic_search_with_assistant",
426+
description="Search threads for a specific assistant",
427+
value={
428+
"query": "conversations about authentication",
429+
"limit": 5,
430+
"assistant_id": "assistant-uuid-here",
431+
},
432+
),
433+
}
414434

415435
LLM_INVOKE_EXAMPLES = {
416436
"stateless_invoke": Example(

backend/src/constants/llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class ChatModels(str, Enum):
1717
OPENAI_REASONING_03 = "openai:o3"
1818
OPENAI_REASONING_04_MINI = "openai:o4-mini"
1919
OPENAI_GPT_4_1_NANO = "openai:gpt-4.1-nano"
20+
OPENAI_GPT_4_1_MINI = "openai:gpt-4.1-mini"
2021
OPENAI_GPT_5_NANO = "openai:gpt-5-nano"
2122
OPENAI_GPT_5_MINI = "openai:gpt-5-mini"
2223
OPENAI_GPT_5 = "openai:gpt-5"
@@ -84,6 +85,7 @@ def get_free_models():
8485
models = []
8586
if OPENAI_API_KEY:
8687
models.append(ChatModels.OPENAI_GPT_5_NANO.value)
88+
models.append(ChatModels.OPENAI_GPT_4_1_MINI.value)
8789
if ANTHROPIC_API_KEY:
8890
models.append(ChatModels.ANTHROPIC_CLAUDE_4_5_HAIKU.value)
8991
if GOOGLE_API_KEY:

backend/src/flows/xml_agent.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from langgraph.checkpoint.base import BaseCheckpointSaver
2222
from langgraph.store.base import BaseStore
2323

24+
from src.utils.format import format_content
25+
2426

2527
###########################################
2628
## Parser
@@ -31,8 +33,9 @@ def input_parser(
3133
xml_lines = ["<thread>"]
3234
for message in messages:
3335
if isinstance(message, HumanMessage):
36+
content = format_content(message.content)
3437
xml_lines.append(
35-
f' <event id="{message.id}" type="{message.type}">{message.content}</event>'
38+
f' <event id="{message.id}" type="{message.type}">{content}</event>'
3639
)
3740
elif isinstance(message, ToolMessage):
3841
xml_lines.append(
@@ -45,8 +48,9 @@ def input_parser(
4548
f' <event id="{tool_call["id"]}" type="tool_input" name="{tool_call["name"]}">{json.dumps(tool_call["args"])}</event>'
4649
)
4750
else:
51+
content = format_content(message.content)
4852
xml_lines.append(
49-
f' <event id="{message.id}" type="{message.type}">{message.content}</event>'
53+
f' <event id="{message.id}" type="{message.type}">{content}</event>'
5054
)
5155
xml_lines.append("</thread>")
5256
return "\n".join(xml_lines) + llm_response_prefix

backend/src/repos/base_repo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _get_namespace(self):
2424
return (self.user_id, self.entity_type)
2525

2626
async def _set(
27-
self, key: str, value: Source | Project | Document, ttl: int | None = None
27+
self, key: str, value: Any, ttl: int | None = None
2828
) -> bool:
2929
await self.store.aput(
3030
namespace=self._get_namespace(),

backend/src/repos/thread_repo.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import asyncio
2+
from langgraph.store.base import BaseStore, SearchItem
3+
from langgraph.store.memory import InMemoryStore
4+
from langgraph.store.postgres.aio import AsyncPostgresStore
5+
6+
from src.services.db import get_store_in_memory
7+
from src.schemas.entities import SearchFilter
8+
from src.constants import THREAD_SNAPSHOT_MESSAGE_COUNT
9+
from src.repos.base_repo import BaseRepo
10+
from src.schemas.entities.store import ThreadSnapshot
11+
from src.utils.logger import logger
12+
from src.utils.format import format_xml_thread
13+
from src.utils.messages import from_message_to_dict
14+
15+
16+
FIELDS = ["messages", "files"]
17+
18+
class ThreadRepo(BaseRepo):
19+
def __init__(self, user_id: str, store: BaseStore = get_store_in_memory(fields=FIELDS)):
20+
## Add fields to the store (if supported)
21+
self.user_id = user_id
22+
self.store: BaseStore = store
23+
24+
try:
25+
self.store.fields = FIELDS
26+
except AttributeError:
27+
pass
28+
super().__init__(user_id=user_id, store=store, entity_type="threads")
29+
30+
31+
async def search(
32+
self,
33+
search_filter: SearchFilter,
34+
) -> list[dict]:
35+
try:
36+
max_retries = 3
37+
retry_delay = 1 # seconds
38+
39+
for attempt in range(max_retries):
40+
try:
41+
async with self.store as store:
42+
if search_filter.query:
43+
queried_threads: list[SearchItem] = await store.asearch(
44+
self._get_namespace(),
45+
limit=search_filter.limit,
46+
filter=search_filter.filter,
47+
query=search_filter.query,
48+
)
49+
return [
50+
ThreadSnapshot(
51+
id=thread.key,
52+
messages=thread.value["messages"],
53+
files=thread.value["files"],
54+
score=thread.score,
55+
updated_at=thread.updated_at
56+
).model_dump(exclude_none=True) for thread in queried_threads
57+
]
58+
threads = await store.asearch(
59+
self._get_namespace(),
60+
limit=search_filter.limit,
61+
filter=search_filter.filter,
62+
)
63+
return sorted(
64+
[thread.dict() for thread in threads],
65+
key=lambda x: x.get("updated_at"),
66+
reverse=True,
67+
)
68+
except Exception as e:
69+
error_msg = str(e).lower()
70+
if "connection" in error_msg and "closed" in error_msg:
71+
logger.warning(
72+
f"Store connection closed on attempt {attempt + 1}/{max_retries}: {e}"
73+
)
74+
if attempt < max_retries - 1:
75+
await asyncio.sleep(
76+
retry_delay * (2**attempt)
77+
) # Exponential backoff
78+
continue
79+
raise e
80+
except Exception as e:
81+
logger.error(f"Error searching threads: {e}")
82+
return []
83+
84+
async def update(self, thread_id: str, data: dict):
85+
86+
# Extract last human message for storage
87+
messages = data.get("messages", [])
88+
messages = from_message_to_dict(messages, include_tool_calls=False)
89+
recent_messages = (
90+
messages[-THREAD_SNAPSHOT_MESSAGE_COUNT:]
91+
if len(messages) > THREAD_SNAPSHOT_MESSAGE_COUNT
92+
else messages
93+
)
94+
95+
data["messages"] = recent_messages
96+
97+
await self.store.aput(
98+
namespace=self._get_namespace(), key=thread_id, value=data
99+
)
100+
101+
return True
102+
103+
104+
async def get(self, thread_id: str) -> dict:
105+
return await self._get(thread_id)
106+
107+
async def delete(self, thread_id: str) -> bool:
108+
try:
109+
await self._delete(thread_id)
110+
logger.info(f"Thread {thread_id} deleted successfully")
111+
return True
112+
except Exception as e:
113+
logger.error(f"Error deleting thread: {e}")
114+
return False
115+
116+
async def _upsert_snapshot(self, thread_id: str, messages: list) -> bool:
117+
"""Create or update a thread snapshot with recent messages.
118+
119+
Note: messages should already be filtered to recent messages before calling this method.
120+
"""
121+
try:
122+
# Extract recent messages for snapshot (last N messages)
123+
recent_messages = (
124+
messages[-THREAD_SNAPSHOT_MESSAGE_COUNT:]
125+
if len(messages) > THREAD_SNAPSHOT_MESSAGE_COUNT
126+
else messages
127+
)
128+
129+
# Format messages as "Role: content" pairs
130+
page_content = format_xml_thread(recent_messages, include_tool_calls=False)
131+
132+
# Create snapshot with metadata
133+
snapshot = ThreadSnapshot(
134+
thread_id=thread_id,
135+
page_content=page_content,
136+
metadata={
137+
"thread_id": thread_id,
138+
"message_count": len(messages),
139+
}
140+
)
141+
142+
await self._set(thread_id, snapshot)
143+
return True
144+
145+
except Exception as e:
146+
logger.error(f"Failed to upsert thread snapshot for {thread_id}: {e}")
147+
return False
148+

backend/src/routes/v0/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ async def list_models():
232232
return JSONResponse(
233233
status_code=status.HTTP_200_OK,
234234
content={
235-
"default": ChatModels.OPENAI_GPT_5_NANO.value,
235+
"default": ChatModels.OPENAI_GPT_4_1_MINI.value,
236236
"free": get_free_models(),
237237
"models": get_all_models(),
238238
},

backend/src/routes/v0/thread.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from fastapi.responses import Response
44
from langgraph.store.base import BaseStore
55
from src.contexts.service import ServiceContext
6-
from src.schemas.entities import ThreadSearch
6+
from src.schemas.entities import SearchFilter, ThreadSemanticSearchRequest
77
from src.utils.logger import logger
88
from src.constants.examples import Examples
99
from src.schemas.models import ProtectedUser
@@ -16,27 +16,26 @@
1616

1717
@router.post("/threads/search", name="Query Threads in Checkpointer")
1818
async def search_threads(
19-
thread_search: ThreadSearch = Body(
19+
search_filter: SearchFilter = Body(
2020
openapi_examples=Examples.THREAD_SEARCH_EXAMPLES
2121
),
2222
user: ProtectedUser = Depends(verify_credentials),
2323
store: AsyncPostgresStore = Depends(get_store),
2424
):
2525
try:
26-
filter = thread_search.model_dump(exclude_none=True).get("filter", {})
2726
async with get_checkpoint_db() as checkpointer:
2827
service_context = ServiceContext(
2928
user_id=user.id, store=store, checkpointer=checkpointer
3029
)
31-
if "thread_id" in filter and not "checkpoint_id" in filter:
30+
if "thread_id" in search_filter.filter and not "checkpoint_id" in search_filter.filter:
3231
checkpoints = await service_context.checkpoint_service.list_checkpoints(
33-
thread_id=filter["thread_id"]
32+
thread_id=search_filter.filter["thread_id"]
3433
)
3534
# if not checkpoints:
3635
# raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Checkpoints not found")
3736
return {"checkpoints": checkpoints}
3837

39-
threads = await service_context.thread_service.search(filter=filter)
38+
threads = await service_context.thread_service.search(search_filter)
4039
return {"threads": threads}
4140
except Exception as e:
4241
logger.exception(f"Error searching threads: {e}")
@@ -118,3 +117,63 @@ async def update_thread(
118117
raise HTTPException(
119118
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
120119
)
120+
121+
122+
@router.post("/threads/search/semantic", name="Semantic Search Over Threads")
123+
async def semantic_search_threads(
124+
request: ThreadSemanticSearchRequest = Body(
125+
openapi_examples=Examples.THREAD_SEMANTIC_SEARCH_EXAMPLES
126+
),
127+
user: ProtectedUser = Depends(verify_credentials),
128+
store: AsyncPostgresStore = Depends(get_store),
129+
):
130+
try:
131+
# Validate query is not empty
132+
if not request.query or request.query.strip() == "":
133+
raise HTTPException(
134+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
135+
detail="query field is required and must not be empty"
136+
)
137+
138+
async with get_checkpoint_db() as checkpointer:
139+
service_context = ServiceContext(
140+
user_id=user.id, store=store, checkpointer=checkpointer
141+
)
142+
143+
# Perform semantic search
144+
search_results = await service_context.thread_service.thread_snapshot_repo.search(
145+
query=request.query,
146+
limit=request.limit,
147+
assistant_id=request.assistant_id
148+
)
149+
150+
# Enrich results with thread titles
151+
enriched_results = []
152+
for result in search_results:
153+
thread_id = result.get("thread_id")
154+
if thread_id:
155+
# Fetch thread data to get title
156+
thread_data = await service_context.thread_service.get(thread_id)
157+
title = "Untitled Thread"
158+
if thread_data and thread_data.value:
159+
# Try to get title from thread data, fallback to first message
160+
title = thread_data.value.get("title", title)
161+
162+
enriched_results.append({
163+
"thread_id": thread_id,
164+
"title": title,
165+
"excerpt": result.get("excerpt", ""),
166+
"score": result.get("score", 0.0),
167+
"updated_at": result.get("updated_at"),
168+
})
169+
170+
return {"results": enriched_results}
171+
172+
except HTTPException:
173+
raise
174+
except Exception as e:
175+
logger.exception(f"Error performing semantic search: {e}")
176+
raise HTTPException(
177+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
178+
detail=str(e)
179+
)

backend/src/schemas/entities/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import BaseModel, Field
66

77
from src.schemas.entities.llm import *
8+
from src.schemas.entities.store import ThreadSnapshot
89
from src.constants.examples import (
910
ADD_DOCUMENTS_EXAMPLE,
1011
THREAD_HISTORY_EXAMPLE,
@@ -134,3 +135,19 @@ class SearchFilter(BaseModel):
134135
"example": {"query": "", "filter": {}, "limit": 20, "offset": 0}
135136
}
136137
}
138+
139+
140+
class ThreadSemanticSearchRequest(BaseModel):
141+
query: str = Field(..., description="Natural language search query")
142+
limit: int = Field(default=10, description="Maximum number of results (max 50)")
143+
assistant_id: Optional[str] = Field(default=None, description="Optional assistant ID to filter results")
144+
145+
model_config = {
146+
"json_schema_extra": {
147+
"example": {
148+
"query": "threads about database optimization",
149+
"limit": 10,
150+
"assistant_id": None
151+
}
152+
}
153+
}

0 commit comments

Comments
 (0)