Skip to content

Commit e64ad34

Browse files
committed
todo: fix nil storage panic and race condition in ID generation
Address PR review feedback: - WithStorage now panics early if given nil storage - Replace length-based ID generation with atomic counter to prevent duplicate IDs under concurrent access Assisted-By: cagent
1 parent f0a20bf commit e64ad34

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pkg/tools/builtin/todo.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"strings"
77
"sync"
8+
"sync/atomic"
89

910
"github.com/docker/cagent/pkg/concurrent"
1011
"github.com/docker/cagent/pkg/tools"
@@ -106,14 +107,19 @@ func (s *MemoryTodoStorage) Clear() {
106107
type TodoOption func(*TodoTool)
107108

108109
// WithStorage sets a custom storage implementation for the TodoTool.
110+
// The provided storage must not be nil.
109111
func WithStorage(storage TodoStorage) TodoOption {
112+
if storage == nil {
113+
panic("todo: storage must not be nil")
114+
}
110115
return func(t *TodoTool) {
111116
t.handler.storage = storage
112117
}
113118
}
114119

115120
type todoHandler struct {
116121
storage TodoStorage
122+
nextID atomic.Int64
117123
}
118124

119125
var NewSharedTodoTool = sync.OnceValue(func() *TodoTool { return NewTodoTool() })
@@ -152,7 +158,7 @@ This toolset is REQUIRED for maintaining task state and ensuring all steps are c
152158
}
153159

154160
func (h *todoHandler) createTodo(_ context.Context, params CreateTodoArgs) (*tools.ToolCallResult, error) {
155-
id := fmt.Sprintf("todo_%d", h.storage.Len()+1)
161+
id := fmt.Sprintf("todo_%d", h.nextID.Add(1))
156162
todo := Todo{
157163
ID: id,
158164
Description: params.Description,
@@ -167,11 +173,11 @@ func (h *todoHandler) createTodo(_ context.Context, params CreateTodoArgs) (*too
167173
}
168174

169175
func (h *todoHandler) createTodos(_ context.Context, params CreateTodosArgs) (*tools.ToolCallResult, error) {
170-
start := h.storage.Len()
176+
ids := make([]int64, len(params.Descriptions))
171177
for i, desc := range params.Descriptions {
172-
id := fmt.Sprintf("todo_%d", start+i+1)
178+
ids[i] = h.nextID.Add(1)
173179
h.storage.Add(Todo{
174-
ID: id,
180+
ID: fmt.Sprintf("todo_%d", ids[i]),
175181
Description: desc,
176182
Status: "pending",
177183
})
@@ -183,7 +189,7 @@ func (h *todoHandler) createTodos(_ context.Context, params CreateTodosArgs) (*t
183189
if i > 0 {
184190
output.WriteString(", ")
185191
}
186-
fmt.Fprintf(&output, "[todo_%d]", start+i+1)
192+
fmt.Fprintf(&output, "[todo_%d]", ids[i])
187193
}
188194

189195
return &tools.ToolCallResult{

pkg/tools/builtin/todo_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,12 @@ func TestTodoTool_WithStorage(t *testing.T) {
255255
assert.Equal(t, "Test item", storage.All()[0].Description)
256256
}
257257

258+
func TestTodoTool_WithStorage_NilPanics(t *testing.T) {
259+
assert.Panics(t, func() {
260+
WithStorage(nil)
261+
})
262+
}
263+
258264
func TestTodoTool_OutputSchema(t *testing.T) {
259265
tool := NewTodoTool()
260266

0 commit comments

Comments
 (0)