@@ -3,6 +3,8 @@ package sqlite
33import (
44 "context"
55 "database/sql"
6+ "fmt"
7+ "strings"
68
79 "github.com/docker/cagent/pkg/memory/database"
810 "github.com/docker/cagent/pkg/sqliteutil"
@@ -26,20 +28,28 @@ func NewMemoryDatabase(path string) (database.Database, error) {
2628 return nil , err
2729 }
2830
31+ // Add category column if it doesn't exist (transparent migration)
32+ if _ , err := db .ExecContext (context .Background (), "ALTER TABLE memories ADD COLUMN category TEXT DEFAULT ''" ); err != nil {
33+ if ! strings .Contains (err .Error (), "duplicate column name" ) {
34+ db .Close ()
35+ return nil , fmt .Errorf ("memory database migration failed: %w" , err )
36+ }
37+ }
38+
2939 return & MemoryDatabase {db : db }, nil
3040}
3141
3242func (m * MemoryDatabase ) AddMemory (ctx context.Context , memory database.UserMemory ) error {
3343 if memory .ID == "" {
3444 return database .ErrEmptyID
3545 }
36- _ , err := m .db .ExecContext (ctx , "INSERT INTO memories (id, created_at, memory) VALUES (?, ?, ?)" ,
37- memory .ID , memory .CreatedAt , memory .Memory )
46+ _ , err := m .db .ExecContext (ctx , "INSERT INTO memories (id, created_at, memory, category ) VALUES (?, ?, ?, ?)" ,
47+ memory .ID , memory .CreatedAt , memory .Memory , memory . Category )
3848 return err
3949}
4050
4151func (m * MemoryDatabase ) GetMemories (ctx context.Context ) ([]database.UserMemory , error ) {
42- rows , err := m .db .QueryContext (ctx , "SELECT id, created_at, memory FROM memories" )
52+ rows , err := m .db .QueryContext (ctx , "SELECT id, created_at, memory, COALESCE(category, '') FROM memories" )
4353 if err != nil {
4454 return nil , err
4555 }
@@ -48,7 +58,7 @@ func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory
4858 var memories []database.UserMemory
4959 for rows .Next () {
5060 var memory database.UserMemory
51- err := rows .Scan (& memory .ID , & memory .CreatedAt , & memory .Memory )
61+ err := rows .Scan (& memory .ID , & memory .CreatedAt , & memory .Memory , & memory . Category )
5262 if err != nil {
5363 return nil , err
5464 }
@@ -66,3 +76,73 @@ func (m *MemoryDatabase) DeleteMemory(ctx context.Context, memory database.UserM
6676 _ , err := m .db .ExecContext (ctx , "DELETE FROM memories WHERE id = ?" , memory .ID )
6777 return err
6878}
79+
80+ func (m * MemoryDatabase ) SearchMemories (ctx context.Context , query , category string ) ([]database.UserMemory , error ) {
81+ var conditions []string
82+ var args []any
83+
84+ if query != "" {
85+ words := strings .Fields (query )
86+ for _ , word := range words {
87+ conditions = append (conditions , "LOWER(memory) LIKE LOWER(?) ESCAPE '\\ '" )
88+ escaped := strings .ReplaceAll (word , `\` , `\\` )
89+ escaped = strings .ReplaceAll (escaped , `%` , `\%` )
90+ escaped = strings .ReplaceAll (escaped , `_` , `\_` )
91+ args = append (args , "%" + escaped + "%" )
92+ }
93+ }
94+
95+ if category != "" {
96+ conditions = append (conditions , "LOWER(category) = LOWER(?)" )
97+ args = append (args , category )
98+ }
99+
100+ stmt := "SELECT id, created_at, memory, COALESCE(category, '') FROM memories"
101+ if len (conditions ) > 0 {
102+ stmt += " WHERE " + strings .Join (conditions , " AND " )
103+ }
104+
105+ rows , err := m .db .QueryContext (ctx , stmt , args ... )
106+ if err != nil {
107+ return nil , err
108+ }
109+ defer rows .Close ()
110+
111+ var memories []database.UserMemory
112+ for rows .Next () {
113+ var memory database.UserMemory
114+ err := rows .Scan (& memory .ID , & memory .CreatedAt , & memory .Memory , & memory .Category )
115+ if err != nil {
116+ return nil , err
117+ }
118+ memories = append (memories , memory )
119+ }
120+
121+ if err := rows .Err (); err != nil {
122+ return nil , err
123+ }
124+
125+ return memories , nil
126+ }
127+
128+ func (m * MemoryDatabase ) UpdateMemory (ctx context.Context , memory database.UserMemory ) error {
129+ if memory .ID == "" {
130+ return database .ErrEmptyID
131+ }
132+
133+ result , err := m .db .ExecContext (ctx , "UPDATE memories SET memory = ?, category = ? WHERE id = ?" ,
134+ memory .Memory , memory .Category , memory .ID )
135+ if err != nil {
136+ return err
137+ }
138+
139+ rows , err := result .RowsAffected ()
140+ if err != nil {
141+ return err
142+ }
143+ if rows == 0 {
144+ return fmt .Errorf ("%w: %s" , database .ErrMemoryNotFound , memory .ID )
145+ }
146+
147+ return nil
148+ }
0 commit comments