Skip to content

Commit 932ef58

Browse files
authored
Merge pull request #1971 from docker/memory-improvements
Add search, update, categories, and default path to memory tool
2 parents 4d1b9f0 + 213280d commit 932ef58

File tree

14 files changed

+529
-59
lines changed

14 files changed

+529
-59
lines changed

docs/configuration/tools/index.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,29 @@ toolsets:
101101

102102
### Memory
103103

104-
Persistent key-value storage backed by SQLite. Data survives across sessions, letting agents remember context, user preferences, and past decisions.
104+
Persistent key-value storage backed by SQLite. Data survives across sessions, letting agents remember context, user preferences, and past decisions. Memories can be organized with categories and searched by keyword.
105+
106+
Each agent gets its own database at `~/.cagent/memory/<agent-name>/memory.db` by default.
105107

106108
```yaml
107109
toolsets:
108110
- type: memory
109-
path: ./agent_memory.db # optional: custom database path
111+
path: ./agent_memory.db # optional: override the default location
110112
```
111113

112-
| Property | Type | Default | Description |
113-
| -------- | ------ | --------- | ---------------------------------------------------------------------- |
114-
| `path` | string | automatic | Path to the SQLite database file. If omitted, uses a default location. |
114+
| Property | Type | Default | Description |
115+
| -------- | ------ | -------------------------------------------- | ------------------------------------ |
116+
| `path` | string | `~/.cagent/memory/<agent-name>/memory.db` | Path to the SQLite database file |
117+
118+
| Operation | Description |
119+
| ------------------ | ------------------------------------------------------------------- |
120+
| `add_memory` | Store a new memory with optional category |
121+
| `get_memories` | Retrieve all stored memories |
122+
| `delete_memory` | Delete a specific memory by ID |
123+
| `search_memories` | Search memories by keywords and/or category (more efficient than get_all) |
124+
| `update_memory` | Update an existing memory's content and/or category by ID |
125+
126+
Memories support an optional `category` field (e.g., `preference`, `fact`, `project`, `decision`) for organization and filtering.
115127

116128
### Fetch
117129

pkg/acp/registry.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
func createToolsetRegistry(agent *Agent) *teamloader.ToolsetRegistry {
1515
registry := teamloader.NewDefaultToolsetRegistry()
1616

17-
registry.Register("filesystem", func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) {
17+
registry.Register("filesystem", func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) {
1818
wd := runConfig.WorkingDir
1919
if wd == "" {
2020
var err error

pkg/config/latest/validate.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ func (t *Toolset) validate() error {
116116
case "shell":
117117
// no additional validation needed
118118
case "memory":
119-
if t.Path == "" {
120-
return errors.New("memory toolset requires a path to be set")
121-
}
119+
// path is optional; defaults to ~/.cagent/memory/<agent-name>/memory.db
122120
case "tasks":
123121
// path defaults to ./tasks.json if not set
124122
case "mcp":

pkg/creator/agent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func createToolsetRegistry(workingDir string) *teamloader.ToolsetRegistry {
120120
}
121121

122122
registry := teamloader.NewDefaultToolsetRegistry()
123-
registry.Register("filesystem", func(context.Context, latest.Toolset, string, *config.RuntimeConfig) (tools.ToolSet, error) {
123+
registry.Register("filesystem", func(context.Context, latest.Toolset, string, *config.RuntimeConfig, string) (tools.ToolSet, error) {
124124
return tracker, nil
125125
})
126126

pkg/creator/agent_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func TestFileWriteTracker(t *testing.T) {
158158
require.NotNil(t, registry)
159159

160160
// Create the toolset through the registry
161-
toolset, err := registry.CreateTool(ctx, latest.Toolset{Type: "filesystem"}, runConfig.WorkingDir, runConfig)
161+
toolset, err := registry.CreateTool(ctx, latest.Toolset{Type: "filesystem"}, runConfig.WorkingDir, runConfig, "test-agent")
162162
require.NoError(t, err)
163163
require.NotNil(t, toolset)
164164

pkg/memory/database/database.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@ import (
55
"errors"
66
)
77

8-
var ErrEmptyID = errors.New("memory ID cannot be empty")
8+
var (
9+
ErrEmptyID = errors.New("memory ID cannot be empty")
10+
ErrMemoryNotFound = errors.New("memory not found")
11+
)
912

1013
type UserMemory struct {
11-
ID string `description:"The ID of the memory"`
12-
CreatedAt string `description:"The creation timestamp of the memory"`
13-
Memory string `description:"The content of the memory"`
14+
ID string `json:"id" description:"The ID of the memory"`
15+
CreatedAt string `json:"created_at" description:"The creation timestamp of the memory"`
16+
Memory string `json:"memory" description:"The content of the memory"`
17+
Category string `json:"category,omitempty" description:"The category of the memory"`
1418
}
1519

1620
type Database interface {
1721
AddMemory(ctx context.Context, memory UserMemory) error
1822
GetMemories(ctx context.Context) ([]UserMemory, error)
1923
DeleteMemory(ctx context.Context, memory UserMemory) error
24+
SearchMemories(ctx context.Context, query, category string) ([]UserMemory, error)
25+
UpdateMemory(ctx context.Context, memory UserMemory) error
2026
}

pkg/memory/database/sqlite/sqlite.go

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package sqlite
33
import (
44
"context"
55
"database/sql"
6+
"fmt"
7+
"strings"
68

79
"github.com/docker/cagent/pkg/memory/database"
810
"github.com/docker/cagent/pkg/sqliteutil"
@@ -26,20 +28,28 @@ func NewMemoryDatabase(path string) (database.Database, error) {
2628
return nil, err
2729
}
2830

31+
// Add category column if it doesn't exist (transparent migration)
32+
if _, err := db.ExecContext(context.Background(), "ALTER TABLE memories ADD COLUMN category TEXT DEFAULT ''"); err != nil {
33+
if !strings.Contains(err.Error(), "duplicate column name") {
34+
db.Close()
35+
return nil, fmt.Errorf("memory database migration failed: %w", err)
36+
}
37+
}
38+
2939
return &MemoryDatabase{db: db}, nil
3040
}
3141

3242
func (m *MemoryDatabase) AddMemory(ctx context.Context, memory database.UserMemory) error {
3343
if memory.ID == "" {
3444
return database.ErrEmptyID
3545
}
36-
_, err := m.db.ExecContext(ctx, "INSERT INTO memories (id, created_at, memory) VALUES (?, ?, ?)",
37-
memory.ID, memory.CreatedAt, memory.Memory)
46+
_, err := m.db.ExecContext(ctx, "INSERT INTO memories (id, created_at, memory, category) VALUES (?, ?, ?, ?)",
47+
memory.ID, memory.CreatedAt, memory.Memory, memory.Category)
3848
return err
3949
}
4050

4151
func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory, error) {
42-
rows, err := m.db.QueryContext(ctx, "SELECT id, created_at, memory FROM memories")
52+
rows, err := m.db.QueryContext(ctx, "SELECT id, created_at, memory, COALESCE(category, '') FROM memories")
4353
if err != nil {
4454
return nil, err
4555
}
@@ -48,7 +58,7 @@ func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory
4858
var memories []database.UserMemory
4959
for rows.Next() {
5060
var memory database.UserMemory
51-
err := rows.Scan(&memory.ID, &memory.CreatedAt, &memory.Memory)
61+
err := rows.Scan(&memory.ID, &memory.CreatedAt, &memory.Memory, &memory.Category)
5262
if err != nil {
5363
return nil, err
5464
}
@@ -66,3 +76,73 @@ func (m *MemoryDatabase) DeleteMemory(ctx context.Context, memory database.UserM
6676
_, err := m.db.ExecContext(ctx, "DELETE FROM memories WHERE id = ?", memory.ID)
6777
return err
6878
}
79+
80+
func (m *MemoryDatabase) SearchMemories(ctx context.Context, query, category string) ([]database.UserMemory, error) {
81+
var conditions []string
82+
var args []any
83+
84+
if query != "" {
85+
words := strings.Fields(query)
86+
for _, word := range words {
87+
conditions = append(conditions, "LOWER(memory) LIKE LOWER(?) ESCAPE '\\'")
88+
escaped := strings.ReplaceAll(word, `\`, `\\`)
89+
escaped = strings.ReplaceAll(escaped, `%`, `\%`)
90+
escaped = strings.ReplaceAll(escaped, `_`, `\_`)
91+
args = append(args, "%"+escaped+"%")
92+
}
93+
}
94+
95+
if category != "" {
96+
conditions = append(conditions, "LOWER(category) = LOWER(?)")
97+
args = append(args, category)
98+
}
99+
100+
stmt := "SELECT id, created_at, memory, COALESCE(category, '') FROM memories"
101+
if len(conditions) > 0 {
102+
stmt += " WHERE " + strings.Join(conditions, " AND ")
103+
}
104+
105+
rows, err := m.db.QueryContext(ctx, stmt, args...)
106+
if err != nil {
107+
return nil, err
108+
}
109+
defer rows.Close()
110+
111+
var memories []database.UserMemory
112+
for rows.Next() {
113+
var memory database.UserMemory
114+
err := rows.Scan(&memory.ID, &memory.CreatedAt, &memory.Memory, &memory.Category)
115+
if err != nil {
116+
return nil, err
117+
}
118+
memories = append(memories, memory)
119+
}
120+
121+
if err := rows.Err(); err != nil {
122+
return nil, err
123+
}
124+
125+
return memories, nil
126+
}
127+
128+
func (m *MemoryDatabase) UpdateMemory(ctx context.Context, memory database.UserMemory) error {
129+
if memory.ID == "" {
130+
return database.ErrEmptyID
131+
}
132+
133+
result, err := m.db.ExecContext(ctx, "UPDATE memories SET memory = ?, category = ? WHERE id = ?",
134+
memory.Memory, memory.Category, memory.ID)
135+
if err != nil {
136+
return err
137+
}
138+
139+
rows, err := result.RowsAffected()
140+
if err != nil {
141+
return err
142+
}
143+
if rows == 0 {
144+
return fmt.Errorf("%w: %s", database.ErrMemoryNotFound, memory.ID)
145+
}
146+
147+
return nil
148+
}

0 commit comments

Comments
 (0)