Skip to content

Commit 329e94e

Browse files
committed
Extract TaskStorage interface for tasks storage
Follow the same pattern used for todos in PR docker#1960: - Define a TaskStorage interface with domain-level methods (All, Get, Put, Delete) that map to what the toolset does - Extract file I/O into FileTaskStorage implementation - Add TaskOption functional options with WithTaskStorage - Move mutex from TasksTool into FileTaskStorage - Refactor all handlers to use the interface directly - Update NewTasksTool to accept variadic options - Add tests for WithTaskStorage and nil panic No in-memory storage implementation is needed for tasks. Assisted-By: cagent
1 parent 26f794c commit 329e94e

File tree

2 files changed

+168
-93
lines changed

2 files changed

+168
-93
lines changed

pkg/tools/builtin/tasks.go

Lines changed: 138 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,101 @@ type taskStore struct {
8787
Tasks map[string]Task `json:"tasks"`
8888
}
8989

90-
type TasksTool struct {
90+
// TaskStorage defines the storage layer for task items.
91+
type TaskStorage interface {
92+
// All returns every task keyed by ID.
93+
All() map[string]Task
94+
// Get returns a single task by ID and whether it was found.
95+
Get(id string) (Task, bool)
96+
// Put creates or updates a task.
97+
Put(task Task) error
98+
// Delete removes a task by ID.
99+
Delete(id string) error
100+
}
101+
102+
// FileTaskStorage is a file-backed, concurrency-safe implementation of TaskStorage.
103+
type FileTaskStorage struct {
91104
mu sync.Mutex
92105
filePath string
106+
}
107+
108+
// NewFileTaskStorage creates a FileTaskStorage that reads/writes to the given path.
109+
func NewFileTaskStorage(filePath string) *FileTaskStorage {
110+
return &FileTaskStorage{filePath: filePath}
111+
}
112+
113+
func (s *FileTaskStorage) load() map[string]Task {
114+
data, err := os.ReadFile(s.filePath)
115+
if err != nil {
116+
return make(map[string]Task)
117+
}
118+
var store taskStore
119+
if err := json.Unmarshal(data, &store); err != nil {
120+
return make(map[string]Task)
121+
}
122+
if store.Tasks == nil {
123+
return make(map[string]Task)
124+
}
125+
return store.Tasks
126+
}
127+
128+
func (s *FileTaskStorage) save(tasks map[string]Task) error {
129+
if err := os.MkdirAll(filepath.Dir(s.filePath), 0o700); err != nil {
130+
return fmt.Errorf("creating storage directory: %w", err)
131+
}
132+
data, err := json.MarshalIndent(taskStore{Tasks: tasks}, "", " ")
133+
if err != nil {
134+
return fmt.Errorf("marshaling task store: %w", err)
135+
}
136+
return os.WriteFile(s.filePath, data, 0o644)
137+
}
138+
139+
func (s *FileTaskStorage) All() map[string]Task {
140+
s.mu.Lock()
141+
defer s.mu.Unlock()
142+
return s.load()
143+
}
144+
145+
func (s *FileTaskStorage) Get(id string) (Task, bool) {
146+
s.mu.Lock()
147+
defer s.mu.Unlock()
148+
tasks := s.load()
149+
t, ok := tasks[id]
150+
return t, ok
151+
}
152+
153+
func (s *FileTaskStorage) Put(task Task) error {
154+
s.mu.Lock()
155+
defer s.mu.Unlock()
156+
tasks := s.load()
157+
tasks[task.ID] = task
158+
return s.save(tasks)
159+
}
160+
161+
func (s *FileTaskStorage) Delete(id string) error {
162+
s.mu.Lock()
163+
defer s.mu.Unlock()
164+
tasks := s.load()
165+
delete(tasks, id)
166+
return s.save(tasks)
167+
}
168+
169+
// TaskOption is a functional option for configuring a TasksTool.
170+
type TaskOption func(*TasksTool)
171+
172+
// WithTaskStorage sets a custom storage implementation for the TasksTool.
173+
// The provided storage must not be nil.
174+
func WithTaskStorage(storage TaskStorage) TaskOption {
175+
if storage == nil {
176+
panic("tasks: storage must not be nil")
177+
}
178+
return func(t *TasksTool) {
179+
t.storage = storage
180+
}
181+
}
182+
183+
type TasksTool struct {
184+
storage TaskStorage
93185
basePath string
94186
}
95187

@@ -98,11 +190,15 @@ var (
98190
_ tools.Instructable = (*TasksTool)(nil)
99191
)
100192

101-
func NewTasksTool(storagePath string) *TasksTool {
102-
return &TasksTool{
103-
filePath: storagePath,
193+
func NewTasksTool(storagePath string, opts ...TaskOption) *TasksTool {
194+
t := &TasksTool{
195+
storage: NewFileTaskStorage(storagePath),
104196
basePath: filepath.Dir(storagePath),
105197
}
198+
for _, opt := range opts {
199+
opt(t)
200+
}
201+
return t
106202
}
107203

108204
func (t *TasksTool) Instructions() string {
@@ -115,32 +211,6 @@ Tasks are saved to a JSON file and survive across sessions. A task is automatica
115211
Workflow: create_task → list_tasks/next_task → update_task as work progresses. Use add_dependency/remove_dependency to manage ordering.`
116212
}
117213

118-
func (t *TasksTool) load() taskStore {
119-
data, err := os.ReadFile(t.filePath)
120-
if err != nil {
121-
return taskStore{Tasks: make(map[string]Task)}
122-
}
123-
var store taskStore
124-
if err := json.Unmarshal(data, &store); err != nil {
125-
return taskStore{Tasks: make(map[string]Task)}
126-
}
127-
if store.Tasks == nil {
128-
store.Tasks = make(map[string]Task)
129-
}
130-
return store
131-
}
132-
133-
func (t *TasksTool) save(store taskStore) error {
134-
if err := os.MkdirAll(filepath.Dir(t.filePath), 0o700); err != nil {
135-
return fmt.Errorf("creating storage directory: %w", err)
136-
}
137-
data, err := json.MarshalIndent(store, "", " ")
138-
if err != nil {
139-
return fmt.Errorf("marshaling task store: %w", err)
140-
}
141-
return os.WriteFile(t.filePath, data, 0o644)
142-
}
143-
144214
func effectiveStatus(task Task, tasks map[string]Task) TaskStatus {
145215
if task.Status == StatusDone {
146216
return StatusDone
@@ -266,22 +336,19 @@ func (t *TasksTool) createTask(_ context.Context, params CreateTaskArgs) (*tools
266336
return tools.ResultError(fmt.Sprintf("invalid priority: %s", params.Priority)), nil
267337
}
268338

269-
t.mu.Lock()
270-
defer t.mu.Unlock()
271-
272-
store := t.load()
339+
all := t.storage.All()
273340
id := uuid.New().String()
274341

275342
deps := params.Dependencies
276343
if deps == nil {
277344
deps = []string{}
278345
}
279346
for _, depID := range deps {
280-
if _, ok := store.Tasks[depID]; !ok {
347+
if _, ok := all[depID]; !ok {
281348
return tools.ResultError(fmt.Sprintf("dependency task not found: %s", depID)), nil
282349
}
283350
}
284-
if hasCycle(store.Tasks, id, deps) {
351+
if hasCycle(all, id, deps) {
285352
return tools.ResultError("adding these dependencies would create a cycle"), nil
286353
}
287354

@@ -296,33 +363,24 @@ func (t *TasksTool) createTask(_ context.Context, params CreateTaskArgs) (*tools
296363
UpdatedAt: now(),
297364
}
298365

299-
store.Tasks[id] = task
300-
if err := t.save(store); err != nil {
366+
if err := t.storage.Put(task); err != nil {
301367
return tools.ResultError(err.Error()), nil
302368
}
303369

304370
return taskResult(task), nil
305371
}
306372

307373
func (t *TasksTool) getTask(_ context.Context, params GetTaskArgs) (*tools.ToolCallResult, error) {
308-
t.mu.Lock()
309-
defer t.mu.Unlock()
310-
311-
store := t.load()
312-
task, ok := store.Tasks[params.ID]
374+
task, ok := t.storage.Get(params.ID)
313375
if !ok {
314376
return tools.ResultError(fmt.Sprintf("task not found: %s", params.ID)), nil
315377
}
316378

317-
return taskWithEffectiveResult(task, store.Tasks), nil
379+
return taskWithEffectiveResult(task, t.storage.All()), nil
318380
}
319381

320382
func (t *TasksTool) updateTask(_ context.Context, params UpdateTaskArgs) (*tools.ToolCallResult, error) {
321-
t.mu.Lock()
322-
defer t.mu.Unlock()
323-
324-
store := t.load()
325-
task, ok := store.Tasks[params.ID]
383+
task, ok := t.storage.Get(params.ID)
326384
if !ok {
327385
return tools.ResultError(fmt.Sprintf("task not found: %s", params.ID)), nil
328386
}
@@ -350,50 +408,53 @@ func (t *TasksTool) updateTask(_ context.Context, params UpdateTaskArgs) (*tools
350408
task.Status = TaskStatus(params.Status)
351409
}
352410
if params.Dependencies != nil {
411+
all := t.storage.All()
353412
for _, depID := range params.Dependencies {
354-
if _, exists := store.Tasks[depID]; !exists {
413+
if _, exists := all[depID]; !exists {
355414
return tools.ResultError(fmt.Sprintf("dependency task not found: %s", depID)), nil
356415
}
357416
}
358-
if hasCycle(store.Tasks, params.ID, params.Dependencies) {
417+
if hasCycle(all, params.ID, params.Dependencies) {
359418
return tools.ResultError("adding these dependencies would create a cycle"), nil
360419
}
361420
task.Dependencies = params.Dependencies
362421
}
363422

364423
task.UpdatedAt = now()
365-
store.Tasks[params.ID] = task
366424

367-
if err := t.save(store); err != nil {
425+
if err := t.storage.Put(task); err != nil {
368426
return tools.ResultError(err.Error()), nil
369427
}
370428

371429
return taskResult(task), nil
372430
}
373431

374432
func (t *TasksTool) deleteTask(_ context.Context, params DeleteTaskArgs) (*tools.ToolCallResult, error) {
375-
t.mu.Lock()
376-
defer t.mu.Unlock()
377-
378-
store := t.load()
379-
if _, ok := store.Tasks[params.ID]; !ok {
433+
if _, ok := t.storage.Get(params.ID); !ok {
380434
return tools.ResultError(fmt.Sprintf("task not found: %s", params.ID)), nil
381435
}
382436

383-
for id, task := range store.Tasks {
437+
// Remove the task from other tasks' dependency lists.
438+
all := t.storage.All()
439+
for id, task := range all {
440+
if id == params.ID {
441+
continue
442+
}
384443
filtered := make([]string, 0, len(task.Dependencies))
385444
for _, d := range task.Dependencies {
386445
if d != params.ID {
387446
filtered = append(filtered, d)
388447
}
389448
}
390-
task.Dependencies = filtered
391-
store.Tasks[id] = task
449+
if len(filtered) != len(task.Dependencies) {
450+
task.Dependencies = filtered
451+
if err := t.storage.Put(task); err != nil {
452+
return tools.ResultError(err.Error()), nil
453+
}
454+
}
392455
}
393456

394-
delete(store.Tasks, params.ID)
395-
396-
if err := t.save(store); err != nil {
457+
if err := t.storage.Delete(params.ID); err != nil {
397458
return tools.ResultError(err.Error()), nil
398459
}
399460

@@ -405,15 +466,12 @@ func (t *TasksTool) deleteTask(_ context.Context, params DeleteTaskArgs) (*tools
405466
}
406467

407468
func (t *TasksTool) listTasks(_ context.Context, params ListTasksArgs) (*tools.ToolCallResult, error) {
408-
t.mu.Lock()
409-
defer t.mu.Unlock()
410-
411-
store := t.load()
469+
all := t.storage.All()
412470
var tasks []taskWithEffective
413-
for _, task := range store.Tasks {
471+
for _, task := range all {
414472
tasks = append(tasks, taskWithEffective{
415473
Task: task,
416-
EffectiveStatus: effectiveStatus(task, store.Tasks),
474+
EffectiveStatus: effectiveStatus(task, all),
417475
})
418476
}
419477

@@ -446,15 +504,12 @@ func (t *TasksTool) listTasks(_ context.Context, params ListTasksArgs) (*tools.T
446504
}
447505

448506
func (t *TasksTool) nextTask(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) {
449-
t.mu.Lock()
450-
defer t.mu.Unlock()
451-
452-
store := t.load()
507+
all := t.storage.All()
453508
var tasks []taskWithEffective
454-
for _, task := range store.Tasks {
509+
for _, task := range all {
455510
tasks = append(tasks, taskWithEffective{
456511
Task: task,
457-
EffectiveStatus: effectiveStatus(task, store.Tasks),
512+
EffectiveStatus: effectiveStatus(task, all),
458513
})
459514
}
460515
sortTasks(tasks)
@@ -475,43 +530,35 @@ func (t *TasksTool) nextTask(_ context.Context, _ tools.ToolCall) (*tools.ToolCa
475530
}
476531

477532
func (t *TasksTool) addDependency(_ context.Context, params AddDependencyArgs) (*tools.ToolCallResult, error) {
478-
t.mu.Lock()
479-
defer t.mu.Unlock()
480-
481-
store := t.load()
482-
task, ok := store.Tasks[params.TaskID]
533+
task, ok := t.storage.Get(params.TaskID)
483534
if !ok {
484535
return tools.ResultError(fmt.Sprintf("task not found: %s", params.TaskID)), nil
485536
}
486-
if _, ok := store.Tasks[params.DependsOnID]; !ok {
537+
if _, ok := t.storage.Get(params.DependsOnID); !ok {
487538
return tools.ResultError(fmt.Sprintf("dependency task not found: %s", params.DependsOnID)), nil
488539
}
489540
if slices.Contains(task.Dependencies, params.DependsOnID) {
490541
return tools.ResultError("dependency already exists"), nil
491542
}
492543

493544
newDeps := append(task.Dependencies, params.DependsOnID)
494-
if hasCycle(store.Tasks, params.TaskID, newDeps) {
545+
all := t.storage.All()
546+
if hasCycle(all, params.TaskID, newDeps) {
495547
return tools.ResultError("adding this dependency would create a cycle"), nil
496548
}
497549

498550
task.Dependencies = newDeps
499551
task.UpdatedAt = now()
500-
store.Tasks[params.TaskID] = task
501552

502-
if err := t.save(store); err != nil {
553+
if err := t.storage.Put(task); err != nil {
503554
return tools.ResultError(err.Error()), nil
504555
}
505556

506557
return taskResult(task), nil
507558
}
508559

509560
func (t *TasksTool) removeDependency(_ context.Context, params RemoveDependencyArgs) (*tools.ToolCallResult, error) {
510-
t.mu.Lock()
511-
defer t.mu.Unlock()
512-
513-
store := t.load()
514-
task, ok := store.Tasks[params.TaskID]
561+
task, ok := t.storage.Get(params.TaskID)
515562
if !ok {
516563
return tools.ResultError(fmt.Sprintf("task not found: %s", params.TaskID)), nil
517564
}
@@ -524,9 +571,8 @@ func (t *TasksTool) removeDependency(_ context.Context, params RemoveDependencyA
524571
}
525572
task.Dependencies = filtered
526573
task.UpdatedAt = now()
527-
store.Tasks[params.TaskID] = task
528574

529-
if err := t.save(store); err != nil {
575+
if err := t.storage.Put(task); err != nil {
530576
return tools.ResultError(err.Error()), nil
531577
}
532578

0 commit comments

Comments
 (0)