Skip to content

Commit accdbdb

Browse files
Fix interrupt handling for suspended tasks
Previously, task suspension did not properly cancel running task execution, causing tasks to continue processing even after being marked as suspended. Key changes: • Add TaskSuspendedEvent to distinguish suspension from normal task events • Subscribe to suspension events in TaskReconciler to cancel active tasks • Resume tasks when new messages arrive by resetting DesiredPhase to running • Check DesiredPhase instead of Phase when computing task status This ensures tasks can be properly interrupted and later resumed without leaving zombie executions running in the background. Co-authored-by: construct-agent <noreply@construct.sh>
1 parent c2923ff commit accdbdb

6 files changed

Lines changed: 28 additions & 6 deletions

File tree

backend/agent/conv.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,8 @@ func ConvertMemoryMessageToProto(m *memory.Message) (*v1.Message, error) {
581581
InputTokens: m.Usage.InputTokens,
582582
OutputTokens: m.Usage.OutputTokens,
583583
CacheWriteTokens: m.Usage.CacheWriteTokens,
584+
CacheReadTokens: m.Usage.CacheReadTokens,
585+
Cost: m.Usage.Cost,
584586
}
585587
}
586588

backend/agent/task_reconciler.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,20 @@ func (r *TaskReconciler) Run(ctx context.Context) error {
103103
}()
104104
}
105105

106-
sub := event.Subscribe(r.bus, func(ctx context.Context, e event.TaskEvent) {
106+
taskEventSub := event.Subscribe(r.bus, func(ctx context.Context, e event.TaskEvent) {
107107
r.queue.Add(e.TaskID)
108108
}, nil)
109109

110+
taskSuspendedEventSub := event.Subscribe(r.bus, func(ctx context.Context, e event.TaskSuspendedEvent) {
111+
cancel, ok := r.runningTasks.Get(e.TaskID)
112+
if ok {
113+
cancel()
114+
}
115+
}, nil)
116+
110117
<-ctx.Done()
111-
sub.Unsubscribe()
118+
taskEventSub.Unsubscribe()
119+
taskSuspendedEventSub.Unsubscribe()
112120

113121
r.queue.ShutDownWithDrain()
114122

@@ -202,7 +210,6 @@ func (r *TaskReconciler) reconcile(ctx context.Context, taskID uuid.UUID) (Resul
202210
return Result{}, fmt.Errorf("failed to fetch messages: %w", err)
203211
}
204212

205-
// Trigger title generation if needed
206213
if shouldGenerateTitle(task, messages) {
207214
go r.generateTitleAsync(taskID)
208215
}
@@ -256,7 +263,7 @@ func (r *TaskReconciler) fetchTaskWithAgent(ctx context.Context, taskID uuid.UUI
256263

257264
// computeStatus analyzes the message history and determines what action to take
258265
func (r *TaskReconciler) computeStatus(task *memory.Task, messages []*memory.Message) (*TaskStatus, error) {
259-
if task.Phase == types.TaskPhaseSuspended {
266+
if task.DesiredPhase == types.TaskPhaseSuspended {
260267
return &TaskStatus{Phase: TaskPhaseSuspended}, nil
261268
}
262269

backend/api/message.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ func (h *MessageHandler) CreateMessage(ctx context.Context, req *connect.Request
4747
return nil, err
4848
}
4949

50+
if task.DesiredPhase == types.TaskPhaseSuspended {
51+
_, err = tx.Task.UpdateOneID(taskID).SetDesiredPhase(types.TaskPhaseRunning).Save(ctx)
52+
if err != nil {
53+
return nil, err
54+
}
55+
}
56+
5057
return tx.Message.Create().
5158
SetTask(task).
5259
SetContent(conv.ConvertProtoContentToMemory(req.Msg.Content)).

backend/api/task.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ func (h *TaskHandler) SuspendTask(ctx context.Context, req *connect.Request[v1.S
310310
return nil, apiError(err)
311311
}
312312

313-
event.Publish(h.eventBus, event.TaskEvent{
313+
event.Publish(h.eventBus, event.TaskSuspendedEvent{
314314
TaskID: taskID,
315315
})
316316
return connect.NewResponse(&v1.SuspendTaskResponse{}), nil

backend/event/events.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ type TaskEvent struct {
88

99
func (TaskEvent) Event() {}
1010

11+
type TaskSuspendedEvent struct {
12+
TaskID uuid.UUID
13+
}
14+
15+
func (TaskSuspendedEvent) Event() {}
16+
1117
type MessageEvent struct {
1218
MessageID uuid.UUID
1319
TaskID uuid.UUID

shared/services.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ func (u *DefaultUserInfo) ConstructLogDir() (string, error) {
199199
default:
200200
logDir = filepath.Join(xdg.StateHome, "construct")
201201
}
202-
202+
203203
if err := u.fs.MkdirAll(logDir, 0700); err != nil {
204204
return "", fmt.Errorf("failed to create log directory: %w", err)
205205
}

0 commit comments

Comments
 (0)