Skip to content

Commit a015bd0

Browse files
committed
fix(jobs): honor cancellation in terminal updates
Preserve cancel-requested runs as cancelled when late completion arrives, and keep their cancellation metadata intact. Also cap synthetic progressive-message persistence with the caller's context so timeout handling stops promptly instead of hanging past deadlines.
1 parent 1eb1eee commit a015bd0

6 files changed

Lines changed: 272 additions & 13 deletions

File tree

cmd/progressive.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ const (
1919
progressiveStopWhenDone progressiveStopWhen = "done"
2020
progressiveStopWhenTimeout progressiveStopWhen = "timeout"
2121

22-
progressiveDefaultFinalizeGrace = 5 * time.Minute
23-
progressiveMaxFinalizeBudget = 5 * time.Minute
24-
progressiveMinFinalizeBudget = 5 * time.Second
22+
progressiveDefaultFinalizeGrace = 5 * time.Minute
23+
progressiveMaxFinalizeBudget = 5 * time.Minute
24+
progressiveMinFinalizeBudget = 5 * time.Second
25+
progressiveSyntheticMessageTimeout = 5 * time.Second
2526
)
2627

2728
type askProgressiveOptions struct {
@@ -296,7 +297,10 @@ func runProgressiveSession(ctx context.Context, engine *llm.Engine, req llm.Requ
296297
continueMsg := llm.UserText(expandProgressiveTemplate(opts.ContinueWith, mainCtx))
297298
history = append(history, continueMsg)
298299
if opts.OnSyntheticUserMessage != nil {
299-
if err := opts.OnSyntheticUserMessage(context.Background(), continueMsg); err != nil {
300+
msgCtx, cancel := progressiveSyntheticUserMessageContext(mainCtx)
301+
err := opts.OnSyntheticUserMessage(msgCtx, continueMsg)
302+
cancel()
303+
if err != nil {
300304
return buildProgressiveRunResult(opts.SessionID, exitReasonNatural, false, tracker.latest, lastText), err
301305
}
302306
}
@@ -425,7 +429,10 @@ func attemptProgressiveFinalization(parentCtx context.Context, engine *llm.Engin
425429
finalPrompt := buildProgressiveFinalizePrompt(tracker.latest)
426430
finalizeMsg := llm.UserText(finalPrompt)
427431
if opts.OnSyntheticUserMessage != nil {
428-
if err := opts.OnSyntheticUserMessage(context.Background(), finalizeMsg); err != nil {
432+
msgCtx, cancel := progressiveSyntheticUserMessageContext(finalizeCtx)
433+
err := opts.OnSyntheticUserMessage(msgCtx, finalizeMsg)
434+
cancel()
435+
if err != nil {
429436
return false, ""
430437
}
431438
}
@@ -513,6 +520,15 @@ func progressiveHasRemainingBudget(parent context.Context, reserve time.Duration
513520
return time.Until(deadline) > reserve
514521
}
515522

523+
// progressiveSyntheticUserMessageContext caps synthetic user-message persistence
524+
// to a short window while still honoring the caller's cancellation/deadline.
525+
func progressiveSyntheticUserMessageContext(parent context.Context) (context.Context, context.CancelFunc) {
526+
if parent == nil {
527+
parent = context.Background()
528+
}
529+
return context.WithTimeout(parent, progressiveSyntheticMessageTimeout)
530+
}
531+
516532
func progressiveFinalizeReserve(total time.Duration) time.Duration {
517533
if total <= 0 {
518534
return 0

cmd/progressive_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package cmd
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"testing"
78
"time"
89

910
"github.com/samsaffron/term-llm/internal/llm"
11+
toolpkg "github.com/samsaffron/term-llm/internal/tools"
1012
)
1113

1214
type progressiveTestTool struct{}
@@ -175,6 +177,87 @@ func TestRunProgressiveSessionTimeoutDoesNotStartDetachedFinalizationAfterDeadli
175177
}
176178
}
177179

180+
func TestRunProgressiveSessionSyntheticContinuationPersistenceHonorsContext(t *testing.T) {
181+
provider := llm.NewMockProvider("mock").WithCapabilities(llm.Capabilities{
182+
ToolCalls: true,
183+
SupportsToolChoice: true,
184+
})
185+
provider.AddToolCall("progress-1", "update_progress", map[string]any{
186+
"state": map[string]any{
187+
"step": "draft",
188+
},
189+
"reason": "milestone",
190+
"message": "draft saved",
191+
})
192+
provider.AddTextResponse("draft answer")
193+
194+
engine := llm.NewEngine(provider, nil)
195+
ctx, cancel := context.WithTimeout(context.Background(), 5050*time.Millisecond)
196+
defer cancel()
197+
198+
done := make(chan error, 1)
199+
go func() {
200+
_, err := runProgressiveSession(ctx, engine, llm.Request{
201+
Messages: []llm.Message{llm.UserText("Investigate X")},
202+
MaxTurns: 8,
203+
ToolChoice: llm.ToolChoice{
204+
Mode: llm.ToolChoiceAuto,
205+
},
206+
}, progressiveRunOptions{
207+
StopWhen: progressiveStopWhenTimeout,
208+
OnSyntheticUserMessage: func(ctx context.Context, msg llm.Message) error {
209+
<-ctx.Done()
210+
return ctx.Err()
211+
},
212+
})
213+
done <- err
214+
}()
215+
216+
select {
217+
case err := <-done:
218+
if !errors.Is(err, context.DeadlineExceeded) {
219+
t.Fatalf("runProgressiveSession error = %v, want deadline exceeded", err)
220+
}
221+
case <-time.After(500 * time.Millisecond):
222+
t.Fatal("synthetic continuation persistence did not stop when its context expired")
223+
}
224+
}
225+
226+
func TestAttemptProgressiveFinalizationSyntheticPersistenceHonorsContext(t *testing.T) {
227+
provider := llm.NewMockProvider("mock")
228+
engine := llm.NewEngine(provider, nil)
229+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
230+
defer cancel()
231+
232+
done := make(chan struct{})
233+
go func() {
234+
finalized, text := attemptProgressiveFinalization(ctx, engine, toolpkg.NewFinalizeProgressTool(), llm.Request{
235+
Messages: []llm.Message{llm.UserText("Investigate X")},
236+
}, nil, progressiveRunOptions{
237+
OnSyntheticUserMessage: func(ctx context.Context, msg llm.Message) error {
238+
<-ctx.Done()
239+
return ctx.Err()
240+
},
241+
}, newProgressTracker(), 20*time.Millisecond, exitReasonTimeout)
242+
if finalized {
243+
t.Errorf("finalized = true, want false")
244+
}
245+
if text != "" {
246+
t.Errorf("text = %q, want empty", text)
247+
}
248+
close(done)
249+
}()
250+
251+
select {
252+
case <-done:
253+
if len(provider.Requests) != 0 {
254+
t.Fatalf("provider saw %d requests, want synthetic persistence failure to stop finalization before the model call", len(provider.Requests))
255+
}
256+
case <-time.After(500 * time.Millisecond):
257+
t.Fatal("synthetic finalization persistence did not stop when its context expired")
258+
}
259+
}
260+
178261
func TestProgressiveFinalizationContextNaturalCompletionDetachesFromParent(t *testing.T) {
179262
parent, cancelParent := context.WithCancel(context.Background())
180263
finalizeCtx, finalizeCancel := progressiveFinalizationContext(parent, time.Second, exitReasonNatural)

cmd/serve_jobs_v2.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,13 +1365,32 @@ func (m *jobsV2Manager) finishRun(runID string, status jobsV2RunStatus, result j
13651365
if runErr != nil {
13661366
errText = runErr.Error()
13671367
}
1368-
_, err := m.db.Exec(`UPDATE job_runs_v2 SET status = ?, finished_at = ?, exit_code = ?, error = ?, stdout = ?, stderr = ?, thinking = ?, response = ?, session_id = ?, exit_reason = ?, truncated = ?, turn_count = ?, input_tokens = ?, output_tokens = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`,
1369-
status, now, result.ExitCode, errText, result.Stdout, result.Stderr, result.Thinking, result.Response, result.SessionID,
1370-
exitReason, boolToInt(truncated), result.TurnCount, result.InputTokens, result.OutputTokens,
1371-
runID)
1368+
cancelledErrText := context.Canceled.Error()
1369+
res, err := m.db.Exec(`UPDATE job_runs_v2 SET status = CASE WHEN status = ? THEN ? ELSE ? END, finished_at = ?, exit_code = ?, error = CASE WHEN status = ? THEN ? ELSE ? END, stdout = ?, stderr = ?, thinking = ?, response = ?, session_id = ?, exit_reason = CASE WHEN status = ? THEN ? ELSE ? END, truncated = ?, turn_count = ?, input_tokens = ?, output_tokens = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ? AND status IN (?, ?, ?, ?)`,
1370+
jobsV2RunCancelRequested, jobsV2RunCancelled, status,
1371+
now, result.ExitCode,
1372+
jobsV2RunCancelRequested, cancelledErrText, errText,
1373+
result.Stdout, result.Stderr, result.Thinking, result.Response, result.SessionID,
1374+
jobsV2RunCancelRequested, exitReasonCancelled, exitReason,
1375+
boolToInt(truncated), result.TurnCount, result.InputTokens, result.OutputTokens,
1376+
runID, jobsV2RunQueued, jobsV2RunClaimed, jobsV2RunRunning, jobsV2RunCancelRequested)
13721377
if err != nil {
13731378
return err
13741379
}
1380+
affected, _ := res.RowsAffected()
1381+
if affected == 0 {
1382+
return nil
1383+
}
1384+
1385+
run, err := m.GetRun(runID)
1386+
if err != nil {
1387+
return err
1388+
}
1389+
status = run.Status
1390+
exitReason = run.ExitReason
1391+
truncated = run.Truncated
1392+
errText = run.Error
1393+
13751394
_ = m.addRunEvent(runID, string(status), "run finished", map[string]any{
13761395
"status": status,
13771396
"attempt": attempt,
@@ -1387,10 +1406,6 @@ func (m *jobsV2Manager) finishRun(runID string, status jobsV2RunStatus, result j
13871406
m.notifyRunDone(runID, status, result, exitReason, truncated, errText)
13881407

13891408
if status == jobsV2RunFailed || status == jobsV2RunTimedOut {
1390-
run, err := m.GetRun(runID)
1391-
if err != nil {
1392-
return nil
1393-
}
13941409
job, err := m.GetJob(run.JobID)
13951410
if err != nil {
13961411
return nil

cmd/serve_jobs_v2_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"encoding/json"
7+
"errors"
78
"net/http"
89
"net/http/httptest"
910
"path/filepath"
@@ -303,6 +304,81 @@ func TestJobsV2ManualTriggerAndCancel(t *testing.T) {
303304
}
304305
}
305306

307+
func TestJobsV2FinishRunDoesNotOverrideCancelRequested(t *testing.T) {
308+
mgr, err := newJobsV2Manager(":memory:", 0, nil)
309+
if err != nil {
310+
t.Fatalf("newJobsV2Manager failed: %v", err)
311+
}
312+
defer func() { _ = mgr.Close() }()
313+
314+
job, err := mgr.CreateJob(jobsV2Job{
315+
Name: "cancel-authoritative",
316+
Enabled: true,
317+
RunnerType: jobsV2RunnerProgram,
318+
RunnerConfig: json.RawMessage(`{"command":"echo","args":["x"]}`),
319+
TriggerType: jobsV2TriggerManual,
320+
TriggerConfig: json.RawMessage(`{}`),
321+
RetryPolicy: json.RawMessage(`{"max_attempts":2}`),
322+
})
323+
if err != nil {
324+
t.Fatalf("CreateJob failed: %v", err)
325+
}
326+
if job.ID == "" {
327+
t.Fatal("expected created job to have an id")
328+
}
329+
330+
run, err := mgr.TriggerJob(job.ID)
331+
if err != nil {
332+
t.Fatalf("TriggerJob failed: %v", err)
333+
}
334+
335+
started := time.Now().UTC()
336+
if _, err := mgr.db.Exec(`UPDATE job_runs_v2 SET status = ?, started_at = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, jobsV2RunRunning, started, run.ID); err != nil {
337+
t.Fatalf("mark run running: %v", err)
338+
}
339+
340+
cancelled, err := mgr.CancelRun(run.ID)
341+
if err != nil {
342+
t.Fatalf("CancelRun failed: %v", err)
343+
}
344+
if cancelled.Status != jobsV2RunCancelRequested {
345+
t.Fatalf("cancelled status = %s, want %s", cancelled.Status, jobsV2RunCancelRequested)
346+
}
347+
348+
result := jobsV2RunResult{Response: "partial output", ExitCode: 17}
349+
if err := mgr.finishRun(run.ID, jobsV2RunFailed, result, errors.New("runner failed after cancel request"), run.Attempt); err != nil {
350+
t.Fatalf("finishRun failed: %v", err)
351+
}
352+
353+
current, err := mgr.GetRun(run.ID)
354+
if err != nil {
355+
t.Fatalf("GetRun failed: %v", err)
356+
}
357+
if current.Status != jobsV2RunCancelled {
358+
t.Fatalf("run status = %s, want %s", current.Status, jobsV2RunCancelled)
359+
}
360+
if current.ExitReason != exitReasonCancelled {
361+
t.Fatalf("exit reason = %q, want %q", current.ExitReason, exitReasonCancelled)
362+
}
363+
if current.Error != context.Canceled.Error() {
364+
t.Fatalf("error = %q, want %q", current.Error, context.Canceled.Error())
365+
}
366+
if current.Response != result.Response {
367+
t.Fatalf("response = %q, want %q", current.Response, result.Response)
368+
}
369+
if current.FinishedAt == nil {
370+
t.Fatal("expected cancelled run to have finished_at set")
371+
}
372+
373+
var retryRuns int
374+
if err := mgr.db.QueryRow(`SELECT COUNT(*) FROM job_runs_v2 WHERE job_id = ? AND id != ?`, job.ID, run.ID).Scan(&retryRuns); err != nil {
375+
t.Fatalf("count retry runs: %v", err)
376+
}
377+
if retryRuns != 0 {
378+
t.Fatalf("retry runs = %d, want 0", retryRuns)
379+
}
380+
}
381+
306382
func TestJobsV2CreateDefaultsEnabledWhenOmitted(t *testing.T) {
307383
mgr, err := newJobsV2Manager(":memory:", 0, nil)
308384
if err != nil {

cmd/serve_response_runs.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,19 @@ func (r *responseRun) appendEvent(event string, payload map[string]any) error {
127127
func (r *responseRun) complete(payload map[string]any, usage llm.Usage, sessionUsage llm.Usage) error {
128128
r.mu.Lock()
129129
defer r.mu.Unlock()
130+
if r.cancelRequested {
131+
r.status = "cancelled"
132+
r.errorType = ""
133+
r.errorMessage = ""
134+
r.cancel = nil
135+
r.cancelRequested = false
136+
if response := mapValue(payload["response"]); len(response) > 0 {
137+
response["status"] = "cancelled"
138+
delete(response, "usage")
139+
delete(response, "session_usage")
140+
}
141+
return r.appendEventLocked("response.cancelled", payload, true)
142+
}
130143
r.status = "completed"
131144
r.errorType = ""
132145
r.errorMessage = ""

cmd/serve_response_runs_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,62 @@ func TestResponseRunSubscriberDroppedWhenBufferFull(t *testing.T) {
129129
}
130130
}
131131

132+
func TestResponseRunCompleteHonorsPendingCancellation(t *testing.T) {
133+
run := newResponseRun("resp_cancelled", "sess_test", "", "mock", time.Now().Unix(), func() {})
134+
135+
if !run.cancelRun() {
136+
t.Fatal("expected cancelRun to succeed")
137+
}
138+
139+
if err := run.complete(map[string]any{
140+
"response": map[string]any{
141+
"id": run.id,
142+
"object": "response",
143+
"created": run.created,
144+
"model": run.model,
145+
"status": "completed",
146+
"usage": usagePayload(llm.Usage{InputTokens: 3, OutputTokens: 4}),
147+
"session_usage": usagePayload(llm.Usage{InputTokens: 5, OutputTokens: 6}),
148+
},
149+
}, llm.Usage{InputTokens: 3, OutputTokens: 4}, llm.Usage{InputTokens: 5, OutputTokens: 6}); err != nil {
150+
t.Fatalf("complete failed: %v", err)
151+
}
152+
153+
run.mu.Lock()
154+
defer run.mu.Unlock()
155+
156+
if run.status != "cancelled" {
157+
t.Fatalf("run status = %q, want cancelled", run.status)
158+
}
159+
if run.cancelRequested {
160+
t.Fatal("cancelRequested should be cleared after terminal transition")
161+
}
162+
if len(run.events) != 1 {
163+
t.Fatalf("events = %d, want 1", len(run.events))
164+
}
165+
if run.events[0].Event != "response.cancelled" {
166+
t.Fatalf("event = %q, want response.cancelled", run.events[0].Event)
167+
}
168+
169+
var payload map[string]any
170+
if err := json.Unmarshal(run.events[0].Data, &payload); err != nil {
171+
t.Fatalf("unmarshal terminal payload: %v", err)
172+
}
173+
response, ok := payload["response"].(map[string]any)
174+
if !ok {
175+
t.Fatalf("response payload type = %T", payload["response"])
176+
}
177+
if response["status"] != "cancelled" {
178+
t.Fatalf("response status = %v, want cancelled", response["status"])
179+
}
180+
if _, ok := response["usage"]; ok {
181+
t.Fatalf("cancelled payload unexpectedly retained usage: %#v", response["usage"])
182+
}
183+
if _, ok := response["session_usage"]; ok {
184+
t.Fatalf("cancelled payload unexpectedly retained session_usage: %#v", response["session_usage"])
185+
}
186+
}
187+
132188
func TestResponseRunCompactionKeepsReplayWindowInOrder(t *testing.T) {
133189
run := newResponseRun("resp_compact", "sess_test", "", "mock", time.Now().Unix(), func() {})
134190
run.maxRetainedEvents = 3

0 commit comments

Comments
 (0)