Skip to content

Commit 0049cad

Browse files
fix: Progressive runs ignore cancellation while persisting synthetic follow-up/finalization turns
1 parent f1778cb commit 0049cad

2 files changed

Lines changed: 104 additions & 5 deletions

File tree

cmd/progressive.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ const (
1919
progressiveStopWhenDone progressiveStopWhen = "done"
2020
progressiveStopWhenTimeout progressiveStopWhen = "timeout"
2121

22-
progressiveDefaultFinalizeGrace = 5 * time.Minute
23-
progressiveMaxFinalizeBudget = 5 * time.Minute
24-
progressiveMinFinalizeBudget = 5 * time.Second
22+
progressiveDefaultFinalizeGrace = 5 * time.Minute
23+
progressiveMaxFinalizeBudget = 5 * time.Minute
24+
progressiveMinFinalizeBudget = 5 * time.Second
25+
progressiveSyntheticMessageTimeout = 5 * time.Second
2526
)
2627

2728
type askProgressiveOptions struct {
@@ -296,7 +297,10 @@ func runProgressiveSession(ctx context.Context, engine *llm.Engine, req llm.Requ
296297
continueMsg := llm.UserText(expandProgressiveTemplate(opts.ContinueWith, mainCtx))
297298
history = append(history, continueMsg)
298299
if opts.OnSyntheticUserMessage != nil {
299-
if err := opts.OnSyntheticUserMessage(context.Background(), continueMsg); err != nil {
300+
msgCtx, cancel := progressiveSyntheticUserMessageContext(mainCtx)
301+
err := opts.OnSyntheticUserMessage(msgCtx, continueMsg)
302+
cancel()
303+
if err != nil {
300304
return buildProgressiveRunResult(opts.SessionID, exitReasonNatural, false, tracker.latest, lastText), err
301305
}
302306
}
@@ -425,7 +429,10 @@ func attemptProgressiveFinalization(parentCtx context.Context, engine *llm.Engin
425429
finalPrompt := buildProgressiveFinalizePrompt(tracker.latest)
426430
finalizeMsg := llm.UserText(finalPrompt)
427431
if opts.OnSyntheticUserMessage != nil {
428-
if err := opts.OnSyntheticUserMessage(context.Background(), finalizeMsg); err != nil {
432+
msgCtx, cancel := progressiveSyntheticUserMessageContext(finalizeCtx)
433+
err := opts.OnSyntheticUserMessage(msgCtx, finalizeMsg)
434+
cancel()
435+
if err != nil {
429436
return false, ""
430437
}
431438
}
@@ -513,6 +520,15 @@ func progressiveHasRemainingBudget(parent context.Context, reserve time.Duration
513520
return time.Until(deadline) > reserve
514521
}
515522

523+
// progressiveSyntheticUserMessageContext caps synthetic user-message persistence
524+
// to a short window while still honoring the caller's cancellation/deadline.
525+
func progressiveSyntheticUserMessageContext(parent context.Context) (context.Context, context.CancelFunc) {
526+
if parent == nil {
527+
parent = context.Background()
528+
}
529+
return context.WithTimeout(parent, progressiveSyntheticMessageTimeout)
530+
}
531+
516532
func progressiveFinalizeReserve(total time.Duration) time.Duration {
517533
if total <= 0 {
518534
return 0

cmd/progressive_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package cmd
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"testing"
78
"time"
89

910
"github.com/samsaffron/term-llm/internal/llm"
11+
toolpkg "github.com/samsaffron/term-llm/internal/tools"
1012
)
1113

1214
type progressiveTestTool struct{}
@@ -175,6 +177,87 @@ func TestRunProgressiveSessionTimeoutDoesNotStartDetachedFinalizationAfterDeadli
175177
}
176178
}
177179

180+
func TestRunProgressiveSessionSyntheticContinuationPersistenceHonorsContext(t *testing.T) {
181+
provider := llm.NewMockProvider("mock").WithCapabilities(llm.Capabilities{
182+
ToolCalls: true,
183+
SupportsToolChoice: true,
184+
})
185+
provider.AddToolCall("progress-1", "update_progress", map[string]any{
186+
"state": map[string]any{
187+
"step": "draft",
188+
},
189+
"reason": "milestone",
190+
"message": "draft saved",
191+
})
192+
provider.AddTextResponse("draft answer")
193+
194+
engine := llm.NewEngine(provider, nil)
195+
ctx, cancel := context.WithTimeout(context.Background(), 5050*time.Millisecond)
196+
defer cancel()
197+
198+
done := make(chan error, 1)
199+
go func() {
200+
_, err := runProgressiveSession(ctx, engine, llm.Request{
201+
Messages: []llm.Message{llm.UserText("Investigate X")},
202+
MaxTurns: 8,
203+
ToolChoice: llm.ToolChoice{
204+
Mode: llm.ToolChoiceAuto,
205+
},
206+
}, progressiveRunOptions{
207+
StopWhen: progressiveStopWhenTimeout,
208+
OnSyntheticUserMessage: func(ctx context.Context, msg llm.Message) error {
209+
<-ctx.Done()
210+
return ctx.Err()
211+
},
212+
})
213+
done <- err
214+
}()
215+
216+
select {
217+
case err := <-done:
218+
if !errors.Is(err, context.DeadlineExceeded) {
219+
t.Fatalf("runProgressiveSession error = %v, want deadline exceeded", err)
220+
}
221+
case <-time.After(500 * time.Millisecond):
222+
t.Fatal("synthetic continuation persistence did not stop when its context expired")
223+
}
224+
}
225+
226+
func TestAttemptProgressiveFinalizationSyntheticPersistenceHonorsContext(t *testing.T) {
227+
provider := llm.NewMockProvider("mock")
228+
engine := llm.NewEngine(provider, nil)
229+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
230+
defer cancel()
231+
232+
done := make(chan struct{})
233+
go func() {
234+
finalized, text := attemptProgressiveFinalization(ctx, engine, toolpkg.NewFinalizeProgressTool(), llm.Request{
235+
Messages: []llm.Message{llm.UserText("Investigate X")},
236+
}, nil, progressiveRunOptions{
237+
OnSyntheticUserMessage: func(ctx context.Context, msg llm.Message) error {
238+
<-ctx.Done()
239+
return ctx.Err()
240+
},
241+
}, newProgressTracker(), 20*time.Millisecond, exitReasonTimeout)
242+
if finalized {
243+
t.Errorf("finalized = true, want false")
244+
}
245+
if text != "" {
246+
t.Errorf("text = %q, want empty", text)
247+
}
248+
close(done)
249+
}()
250+
251+
select {
252+
case <-done:
253+
if len(provider.Requests) != 0 {
254+
t.Fatalf("provider saw %d requests, want synthetic persistence failure to stop finalization before the model call", len(provider.Requests))
255+
}
256+
case <-time.After(500 * time.Millisecond):
257+
t.Fatal("synthetic finalization persistence did not stop when its context expired")
258+
}
259+
}
260+
178261
func TestProgressiveFinalizationContextNaturalCompletionDetachesFromParent(t *testing.T) {
179262
parent, cancelParent := context.WithCancel(context.Background())
180263
finalizeCtx, finalizeCancel := progressiveFinalizationContext(parent, time.Second, exitReasonNatural)

0 commit comments

Comments
 (0)