Skip to content

Commit 86df6ed

Browse files
authored
FROM feat/332-add-memories TO development (#337)
* Memories appera to work neew frontend toggle * Can toggle memory * Few fixes * Most of fixes g2g * Should be better than before * Mmoery working * Should fix frontend
1 parent 8158a3c commit 86df6ed

17 files changed

Lines changed: 293 additions & 111 deletions

File tree

Changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3434
- bugfix/28-cannot-auth-tools-list (2024-11-30)
3535

3636
### Changed
37+
- feat/332-add-memories (2025-06-15)
3738
- feat/290-api-as-a-tool-v2 (2025-06-08)
3839
- feat/327-agent-select-from-chatinput (2025-06-06)
3940
- feat/322-add-react-voice-viz (2025-05-29)

backend/src/routes/v0/llm/__init__.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,17 @@ async def new_thread(
9999
db: AsyncSession = Depends(get_async_db)
100100
):
101101
try:
102-
store = None
103-
if hasattr(body, 'memory') and body.memory:
104-
store = get_store_db()
105-
await store.setup()
106-
107-
controller = AgentController(db=db, user_id=user.id if user else None)
108-
output_type = request.headers.get("accept", "application/json")
109-
return await controller.query_thread(output_type=output_type, thread=body, store=store)
102+
async with get_store_db() as store:
103+
args = {
104+
"output_type": request.headers.get("accept", "application/json"),
105+
"thread": body
106+
}
107+
if hasattr(body, 'memory') and body.memory:
108+
await store.setup()
109+
args["store"] = store
110+
111+
controller = AgentController(db=db, user_id=user.id if user else None)
112+
return await controller.query_thread(**args)
110113
except httpx.HTTPStatusError as e:
111114
logger.error(f"Error creating new thread: {str(e)}")
112115
raise HTTPException(status_code=e.response.status_code , detail=str(e))
@@ -149,9 +152,17 @@ async def existing_thread(
149152
):
150153
try:
151154
async with get_store_db() as store:
155+
args = {
156+
"output_type": request.headers.get("accept", "application/json"),
157+
"thread": body,
158+
"thread_id": thread_id
159+
}
160+
if hasattr(body, 'memory') and body.memory:
161+
await store.setup()
162+
args["store"] = store
163+
152164
controller = AgentController(db=db, user_id=user.id if user else None)
153-
output_type = request.headers.get("accept", "application/json")
154-
return await controller.query_thread(output_type=output_type, thread=body, store=store)
165+
return await controller.query_thread(**args)
155166
except httpx.HTTPStatusError as e:
156167
logger.error(f"Error creating new thread: {str(e)}")
157168
raise HTTPException(status_code=e.response.status_code , detail=str(e))

backend/src/services/db.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from psycopg_pool import ConnectionPool
88
from src.constants import DB_URI, CONNECTION_POOL_KWARGS
99
from langgraph.store.postgres import AsyncPostgresStore
10+
from langchain.embeddings.base import init_embeddings
1011

1112
MAX_CONNECTION_POOL_SIZE = None
1213

@@ -33,7 +34,14 @@ def get_checkpoint_db() -> AsyncPostgresSaver:
3334
return AsyncPostgresSaver.from_conn_string(DB_URI)
3435

3536
def get_store_db() -> AsyncPostgresStore:
36-
return AsyncPostgresStore.from_conn_string(DB_URI)
37+
return AsyncPostgresStore.from_conn_string(
38+
conn_string=DB_URI,
39+
index={
40+
"dims": 1536,
41+
"embed": init_embeddings("openai:text-embedding-3-small"),
42+
"fields": ["memory"] # specify which fields to embed. Default is the whole serialized value
43+
}
44+
)
3745

3846
# Session context managers
3947
def get_db() -> Generator[SessionLocal, None, None]: # type: ignore

backend/src/tools/api.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -184,37 +184,30 @@ def make_api_call_func(
184184
data: Optional request body data for POST/PUT requests
185185
path_params: Optional dictionary of path parameters to format into the URL
186186
"""
187-
async def api_call(event: str, data: Dict[str, Any], config: RunnableConfig):
187+
async def api_call(event: str, data: Dict[str, Any] = {}):
188188
wrapper = GenericRequestsWrapper(headers=headers)
189189
# Format URL with path parameters if provided
190190
formatted_url = url.format(**(path_params or {}))
191-
async with get_store_db() as store:
192-
try:
193-
if method == "GET":
194-
response = await wrapper.aget(formatted_url)
195-
elif method == "POST":
196-
response = await wrapper.apost(formatted_url, data={'data': data, "event": event})
197-
elif method == "PUT":
198-
response = await wrapper.aput(formatted_url, data={'data': data, "event": event})
199-
elif method == "PATCH":
200-
response = await wrapper.apatch(formatted_url, data={'data': data, "event": event})
201-
elif method == "DELETE":
202-
response = await wrapper.adelete(formatted_url)
203-
else:
204-
raise ValueError(f"Unsupported HTTP method: {method}")
205-
206-
logger.debug(f"API Response for {method} {formatted_url}: {response}")
207-
if store and config.get("store", None):
208-
await store.aput(("memories", get_user_id(config)), event, {"memory": {
209-
"event": event,
210-
"data": data,
211-
"response": response
212-
}})
213-
return response
214-
215-
except Exception as e:
216-
logger.error(f"API call failed for {method} {formatted_url}: {str(e)}")
217-
raise
191+
try:
192+
if method == "GET":
193+
response = await wrapper.aget(formatted_url)
194+
elif method == "POST":
195+
response = await wrapper.apost(formatted_url, data={'data': data, "event": event})
196+
elif method == "PUT":
197+
response = await wrapper.aput(formatted_url, data={'data': data, "event": event})
198+
elif method == "PATCH":
199+
response = await wrapper.apatch(formatted_url, data={'data': data, "event": event})
200+
elif method == "DELETE":
201+
response = await wrapper.adelete(formatted_url)
202+
else:
203+
raise ValueError(f"Unsupported HTTP method: {method}")
204+
205+
logger.debug(f"API Response for {method} {formatted_url}: {response}")
206+
return response
207+
208+
except Exception as e:
209+
logger.error(f"API call failed for {method} {formatted_url}: {str(e)}")
210+
raise
218211

219212
api_call.__doc__ = description
220213
return api_call

backend/src/tools/memory.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import uuid
2+
from typing import List, TypedDict
3+
from langchain_core.tools import tool, ToolException
4+
from langchain_core.documents import Document
5+
from langchain_core.runnables import RunnableConfig
6+
from src.services.db import get_store_db
7+
from src.utils.tools import get_user_id
8+
from src.utils.logger import logger
9+
10+
class Memory(TypedDict):
11+
memory: str
12+
ttl: int = 10_080 # 1 week
13+
14+
@tool
15+
async def save_recall_memory(
16+
memory: str,
17+
ttl: int = 1440,
18+
config: RunnableConfig = None
19+
) -> str:
20+
"""Save memory to vectorstore for later semantic retrieval."""
21+
memory_id = str(uuid.uuid4())
22+
async with get_store_db() as store:
23+
await store.aput(("memories", get_user_id(config)), memory_id, {"memory": memory}, ttl=ttl)
24+
return f"Memory ID {memory_id} saved."
25+
26+
@tool
27+
async def delete_recall_memory(memory_id: str, config: RunnableConfig) -> str:
28+
"""Delete a specific memory for the current thread."""
29+
30+
async with get_store_db() as store:
31+
await store.adelete(("memories", get_user_id(config)), memory_id)
32+
return f"Memory ID {memory_id} deleted."
33+
34+
## TODO: Not sure if this works correctly, does not appear that way.
35+
@tool()
36+
async def search_recall_memories(query: str, config: RunnableConfig) -> List[str]:
37+
"""
38+
Search for relevant memories.
39+
40+
Args:
41+
query: The question to search for.
42+
43+
Returns:
44+
A list of memories that are relevant to the query.
45+
"""
46+
try:
47+
user_id = get_user_id(config)
48+
49+
async with get_store_db() as store:
50+
documents = await store.asearch(
51+
("memories", get_user_id(config)),
52+
query=query,
53+
limit=3,
54+
filter={"user_id": user_id}
55+
)
56+
print(documents)
57+
return [doc.dict() for doc in documents]
58+
except ToolException as e:
59+
logger.exception(f"Error searching for memories: {e}")
60+
raise ToolException(f"Error searching for memories: {str(e)}")

backend/src/utils/agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import json
32
from fastapi import HTTPException, status
43
from fastapi.responses import Response, JSONResponse, StreamingResponse
54
from langchain_core.messages import AnyMessage
@@ -13,7 +12,8 @@
1312
from src.schemas.entities.a2a import A2AServer
1413
from src.repos.user_repo import UserRepo
1514
from src.constants import APP_LOG_LEVEL
16-
from src.tools import dynamic_tools, init_tools
15+
from src.tools import init_tools
16+
from src.tools.memory import search_recall_memories, save_recall_memory, delete_recall_memory
1717
from src.utils.llm import LLMWrapper
1818
from src.constants.llm import ModelName
1919
from src.schemas.entities import Answer, Thread, ArcadeConfig
@@ -125,13 +125,13 @@ async def abuilder(
125125
)
126126
self.llm = LLMWrapper(model_name=model_name, tools=self.tools, user_repo=self.user_repo)
127127
if self.store:
128-
self.store.embeddings = LLMWrapper(model_name="openai:text-embedding-3-large").embedding_model()
128+
self.tools.extend([search_recall_memories, save_recall_memory, delete_recall_memory])
129129
memories = await self.store.asearch(("memories", self.user_id))
130130
if memories:
131-
system += "\n\nMemories:\n"
131+
system += "\n\n"
132132
for memory in memories:
133133
memory_dict = str(memory.dict())
134-
system += f"{memory_dict}\n"
134+
system += f"<recall_memory>{memory_dict}</recall_memory>\n"
135135

136136
if self.tools:
137137
graph = create_react_agent(self.llm, prompt=system, tools=self.tools, checkpointer=self.checkpointer, store=self.store)

backend/src/utils/tools.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,11 @@ def get_user_id(config: RunnableConfig) -> str:
5858
if user_id is None:
5959
raise ValueError("User ID needs to be provided to save a memory.")
6060

61-
return user_id
61+
return user_id
62+
63+
def get_thread_id(config: RunnableConfig) -> str:
64+
thread_id = config["configurable"].get("thread_id")
65+
if thread_id is None:
66+
raise ValueError("Thread ID needs to be provided to save a memory.")
67+
68+
return thread_id

frontend/src/components/drawers/ConfigDrawer/tabs/tab-content-info.tsx

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,58 @@ import { Button } from "@/components/ui/button"
22
import SelectModel from "@/components/selects/SelectModel"
33
import { useChatContext } from "@/context/ChatContext"
44
import SystemMessageCard from "@/components/cards/SystemMessageCard"
5+
import { Switch } from "@/components/ui/switch"
6+
import { useMemory } from "@/hooks/useMemory"
7+
import { useSystem } from "@/hooks/useSystem"
8+
import { useState } from "react"
59

610
function TabContentInfo() {
7-
const { payload } = useChatContext()
11+
const { payload, setPayload } = useChatContext()
12+
const [completed, ] = useState(false)
13+
14+
const handleMemoryToggle = () => {
15+
setPayload((prev: any) => ({ ...prev, memory: !prev.memory }));
16+
}
17+
18+
const handleSaveSystemPrompt = () => {
19+
setPayload((prev: any) => ({ ...prev, system: payload.system }));
20+
}
21+
822
return (
923
<div className="space-y-6">
1024
<div className="space-y-4">
1125
<h3 className="text-lg font-medium">Model</h3>
1226
<SelectModel />
1327
</div>
1428

29+
<div className="space-y-4">
30+
<h3 className="text-lg font-medium">Memory</h3>
31+
<div className="flex items-center justify-between">
32+
<div className="space-y-1">
33+
<p className="text-sm font-medium">Enable Memory</p>
34+
<p className="text-xs text-muted-foreground">
35+
Allow the AI to remember previous conversations
36+
</p>
37+
</div>
38+
<Switch
39+
checked={useMemory()}
40+
onCheckedChange={handleMemoryToggle}
41+
/>
42+
</div>
43+
</div>
44+
1545
<div className="space-y-4">
1646
<h3 className="text-lg font-medium">System Prompt</h3>
17-
<SystemMessageCard content={payload.system} />
18-
<Button className="mt-4 w-full">Save</Button>
47+
<SystemMessageCard content={useSystem()} />
48+
<Button
49+
className="mt-4 w-full"
50+
onClick={() => {
51+
handleSaveSystemPrompt();
52+
alert("System prompt saved");
53+
}}
54+
>
55+
{completed ? "Saved!" : "Save"}
56+
</Button>
1957
</div>
2058
</div>
2159
)

frontend/src/components/inputs/ChatInput.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ export default function ChatInput() {
126126
onDrop={handleDrop}
127127
onDragOver={(e) => e.preventDefault()}
128128
onKeyDown={(e) => {
129-
if (e.key === "Enter" && !e.shiftKey) {
129+
if (e.key === "Enter" && !e.shiftKey && !isRecording && payload.query.length > 0) {
130130
e.preventDefault()
131131
handleSubmit()
132132
}

frontend/src/components/selects/SelectModel/SelectModel.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ import { SiAnthropic, SiOpenai, SiOllama, SiGoogle } from 'react-icons/si';
66
import GroqIcon from "@/components/icons/GroqIcon";
77
import { useQueryParam, StringParam, withDefault } from 'use-query-params';
88
import { useChatContext } from "@/context/ChatContext";
9-
import { DEFAULT_CHAT_MODEL } from "@/lib/config/llm";
9+
import { useModel } from "@/hooks/useModel";
10+
1011

11-
// Create a parameter with default value
12-
const ModelParam = withDefault(StringParam, DEFAULT_CHAT_MODEL);
1312

1413
function SelectModel() {
14+
// Create a parameter with default value
1515
const { setPayload, models } = useChatContext();
16-
const [model, setModel] = useQueryParam('model', ModelParam);
16+
const [model, setModel] = useQueryParam('model', withDefault(StringParam, useModel()));
1717

1818
const handleModelChange = (modelId: string) => {
1919
setModel(modelId);
@@ -24,7 +24,7 @@ function SelectModel() {
2424
if (model) {
2525
setPayload((prev: any) => ({ ...prev, model }));
2626
}
27-
}, [model, setPayload]);
27+
}, [model]);
2828

2929
return (
3030
<Select value={model} onValueChange={handleModelChange}>

0 commit comments

Comments
 (0)