Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmd/ask.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,9 @@ func runAsk(cmd *cobra.Command, args []string) error {
default:
err = streamPlainText(displayCtx, events, false)
}
if err != nil && bridge != nil {
bridge.Stop()
}
progressiveRun = <-runCh
}
tools.ClearAskUserHooks() // Safe to call even if hooks weren't set
Expand Down
79 changes: 64 additions & 15 deletions cmd/ask_progressive.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package cmd

import (
"errors"
"sync"

"github.com/samsaffron/term-llm/internal/llm"
internalreasoning "github.com/samsaffron/term-llm/internal/reasoning"
"github.com/samsaffron/term-llm/internal/ui"
)

var errAskProgressiveBridgeStopped = errors.New("progressive stream consumer stopped")

type askProgressiveRunResult struct {
Result progressiveRunResult
Err error
Expand All @@ -16,9 +19,11 @@ type askProgressiveRunResult struct {
type askProgressiveBridge struct {
events chan ui.StreamEvent
stats *ui.SessionStats
done chan struct{}

seenToolStarts map[string]struct{}
seenToolEnds map[string]struct{}
stopOnce sync.Once
closeOnce sync.Once

attemptInput int
Expand All @@ -36,6 +41,7 @@ func newAskProgressiveBridge(bufSize int) *askProgressiveBridge {
return &askProgressiveBridge{
events: make(chan ui.StreamEvent, bufSize),
stats: ui.NewSessionStats(),
done: make(chan struct{}),
seenToolStarts: make(map[string]struct{}),
seenToolEnds: make(map[string]struct{}),
}
Expand All @@ -59,16 +65,35 @@ func (b *askProgressiveBridge) Stats() *ui.SessionStats {
return b.stats
}

func (b *askProgressiveBridge) Stop() {
b.stopOnce.Do(func() {
close(b.done)
})
}

func (b *askProgressiveBridge) send(event ui.StreamEvent) error {
select {
case <-b.done:
return errAskProgressiveBridgeStopped
case b.events <- event:
return nil
}
}

func (b *askProgressiveBridge) HandleEvent(event llm.Event) error {
switch event.Type {
case llm.EventError:
if event.Err != nil {
b.events <- ui.ErrorEvent(event.Err)
if err := b.send(ui.ErrorEvent(event.Err)); err != nil {
return err
}
}
case llm.EventTextDelta:
b.attemptUsageCommitted = false
if event.Text != "" {
b.events <- ui.TextEvent(event.Text)
if err := b.send(ui.TextEvent(event.Text)); err != nil {
return err
}
}
case llm.EventReasoningDelta:
b.attemptUsageCommitted = false
Expand All @@ -84,7 +109,9 @@ func (b *askProgressiveBridge) HandleEvent(event llm.Event) error {
if kind == llm.ReasoningKindSummary && event.Text != "" {
title = internalreasoning.ParseReasoningSummary(event.Text).Title
}
b.events <- ui.ReasoningEvent(kind, event.Text, title, event.ReasoningItemID, event.ReasoningFinal, displayable)
if err := b.send(ui.ReasoningEvent(kind, event.Text, title, event.ReasoningItemID, event.ReasoningFinal, displayable)); err != nil {
return err
}
case llm.EventToolCall:
b.markAttemptCommitted()
if event.Tool == nil {
Expand Down Expand Up @@ -112,7 +139,9 @@ func (b *askProgressiveBridge) HandleEvent(event llm.Event) error {
toolArgs = event.Tool.Arguments
}
b.stats.ToolStart()
b.events <- ui.ToolStartEvent(toolCallID, event.Tool.Name, toolInfo, toolArgs)
if err := b.send(ui.ToolStartEvent(toolCallID, event.Tool.Name, toolInfo, toolArgs)); err != nil {
return err
}
case llm.EventToolExecStart:
b.markAttemptCommitted()
if isProgressToolName(event.ToolName) {
Expand All @@ -125,7 +154,9 @@ func (b *askProgressiveBridge) HandleEvent(event llm.Event) error {
b.seenToolStarts[event.ToolCallID] = struct{}{}
}
b.stats.ToolStart()
b.events <- ui.ToolStartEvent(event.ToolCallID, event.ToolName, event.ToolInfo, event.ToolArgs)
if err := b.send(ui.ToolStartEvent(event.ToolCallID, event.ToolName, event.ToolInfo, event.ToolArgs)); err != nil {
return err
}
case llm.EventToolExecEnd:
if isProgressToolName(event.ToolName) {
break
Expand All @@ -138,12 +169,18 @@ func (b *askProgressiveBridge) HandleEvent(event llm.Event) error {
}
b.stats.ToolEnd()
b.resetAttemptUsage()
b.events <- ui.ToolEndEvent(event.ToolCallID, event.ToolName, event.ToolInfo, event.ToolSuccess)
if err := b.send(ui.ToolEndEvent(event.ToolCallID, event.ToolName, event.ToolInfo, event.ToolSuccess)); err != nil {
return err
}
for _, imagePath := range event.ToolImages {
b.events <- ui.ImageEvent(imagePath)
if err := b.send(ui.ImageEvent(imagePath)); err != nil {
return err
}
}
for _, d := range event.ToolDiffs {
b.events <- ui.DiffEventWithOperation(d.File, d.Old, d.New, d.Line, d.Operation)
if err := b.send(ui.DiffEventWithOperation(d.File, d.Old, d.New, d.Line, d.Operation)); err != nil {
return err
}
}
case llm.EventUsage:
if event.Use != nil {
Expand All @@ -155,40 +192,52 @@ func (b *askProgressiveBridge) HandleEvent(event llm.Event) error {
b.attemptCacheWrite += event.Use.CacheWriteTokens
b.attemptUsageCalls++
}
b.events <- ui.UsageEvent(event.Use.InputTokens, event.Use.OutputTokens, event.Use.CachedInputTokens, event.Use.CacheWriteTokens)
if err := b.send(ui.UsageEvent(event.Use.InputTokens, event.Use.OutputTokens, event.Use.CachedInputTokens, event.Use.CacheWriteTokens)); err != nil {
return err
}
}
case llm.EventPhase:
if event.Text != "" {
b.events <- ui.PhaseEvent(event.Text)
if err := b.send(ui.PhaseEvent(event.Text)); err != nil {
return err
}
}
case llm.EventRetry:
b.events <- ui.RetryEvent(event.RetryAttempt, event.RetryMaxAttempts, event.RetryWaitSecs)
if err := b.send(ui.RetryEvent(event.RetryAttempt, event.RetryMaxAttempts, event.RetryWaitSecs)); err != nil {
return err
}
case llm.EventAttemptDiscard:
if b.attemptUsageCalls > 0 {
b.stats.DiscardUsage(b.attemptInput, b.attemptOutput, b.attemptCached, b.attemptCacheWrite, b.attemptUsageCalls)
}
b.resetAttemptUsage()
b.events <- ui.AttemptDiscardEvent()
if err := b.send(ui.AttemptDiscardEvent()); err != nil {
return err
}
case llm.EventInterjection:
if event.Text != "" {
b.events <- ui.InterjectionEvent(event.Text, event.InterjectionID)
if err := b.send(ui.InterjectionEvent(event.Text, event.InterjectionID)); err != nil {
return err
}
}
}
return nil
}

func (b *askProgressiveBridge) CloseSuccess() {
b.closeOnce.Do(func() {
b.events <- ui.DoneEvent(b.stats.OutputTokens)
_ = b.send(ui.DoneEvent(b.stats.OutputTokens))
b.Stop()
close(b.events)
})
}

func (b *askProgressiveBridge) CloseError(err error) {
b.closeOnce.Do(func() {
if err != nil {
b.events <- ui.ErrorEvent(err)
_ = b.send(ui.ErrorEvent(err))
}
b.Stop()
close(b.events)
})
}
54 changes: 54 additions & 0 deletions cmd/ask_progressive_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package cmd

import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/samsaffron/term-llm/internal/llm"
"github.com/samsaffron/term-llm/internal/ui"
Expand Down Expand Up @@ -29,3 +33,53 @@ func TestAskProgressiveBridge_PropagatesInterjectionID(t *testing.T) {
t.Fatalf("event interjection ID = %q, want %q", ev.InterjectionID, "bridge-interject-1")
}
}

type failingJSONWriter struct{}

func (failingJSONWriter) Write(p []byte) (int, error) {
return 0, fmt.Errorf("broken pipe")
}

func TestAskProgressiveBridge_StopUnblocksProducerAfterConsumerWriteError(t *testing.T) {
bridge := newAskProgressiveBridge(1)
runCh := make(chan error, 1)

go func() {
for i := 0; i < 8; i++ {
err := bridge.HandleEvent(llm.Event{Type: llm.EventTextDelta, Text: "chunk"})
if err != nil {
bridge.CloseError(err)
runCh <- err
return
}
}
bridge.CloseSuccess()
runCh <- nil
}()

_, _, writeErr := streamJSONEvents(context.Background(), bridge.Events(), newJSONEmitter(failingJSONWriter{}))
if writeErr == nil {
t.Fatal("expected streamJSONEvents to fail")
}

bridge.Stop()

select {
case err := <-runCh:
if !errors.Is(err, errAskProgressiveBridgeStopped) {
t.Fatalf("producer error = %v, want %v", err, errAskProgressiveBridgeStopped)
}
case <-time.After(2 * time.Second):
t.Fatal("producer remained blocked after bridge.Stop")
}

select {
case _, ok := <-bridge.Events():
if ok {
for range bridge.Events() {
}
}
case <-time.After(2 * time.Second):
t.Fatal("bridge events channel was not closed")
}
}
Loading