Skip to content

Commit 7f7b53e

Browse files
committed
Improve compaction
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent d7f1f6e commit 7f7b53e

3 files changed

Lines changed: 229 additions & 0 deletions

File tree

pkg/runtime/runtime.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,8 +988,17 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
988988

989989
if m != nil && r.sessionCompaction {
990990
if sess.InputTokens+sess.OutputTokens > int64(float64(contextLimit)*0.9) {
991+
messageCountBefore := len(sess.Messages)
991992
r.Summarize(ctx, sess, "", events)
992993
events <- TokenUsage(sess.ID, r.currentAgent, sess.InputTokens, sess.OutputTokens, sess.InputTokens+sess.OutputTokens, contextLimit, sess.Cost)
994+
995+
// Reset token counters after successful compaction so the check
996+
// doesn't trigger again on the next iteration. Compact appends
997+
// items on success, so a larger message count indicates success.
998+
if len(sess.Messages) > messageCountBefore {
999+
sess.InputTokens = 0
1000+
sess.OutputTokens = 0
1001+
}
9931002
}
9941003
}
9951004

pkg/runtime/runtime_test.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"reflect"
9+
"strings"
910
"sync"
1011
"testing"
1112
"time"
@@ -75,6 +76,18 @@ func (m *mockStream) Recv() (chat.MessageStreamResponse, error) {
7576

7677
// Close records that the stream has been closed by setting the closed flag.
func (m *mockStream) Close() {
	m.closed = true
}
7778

79+
// errorStream always returns an error on Recv, simulating a failed model call.
80+
type errorStream struct {
81+
err error
82+
closed bool
83+
}
84+
85+
func (e *errorStream) Recv() (chat.MessageStreamResponse, error) {
86+
return chat.MessageStreamResponse{}, e.err
87+
}
88+
89+
func (e *errorStream) Close() { e.closed = true }
90+
7891
// streamBuilder holds the chat.MessageStreamResponse values kept in its
// responses field.
type streamBuilder struct {
	responses []chat.MessageStreamResponse
}
7992

8093
func newStreamBuilder() *streamBuilder {
@@ -725,6 +738,205 @@ func TestCompaction(t *testing.T) {
725738
require.NotEqual(t, -1, compactionStartIdx, "expected a SessionCompaction start event")
726739
}
727740

741+
// capturingQueueProvider extends queueProvider to also capture the messages
742+
// passed to each CreateChatCompletionStream call.
743+
type capturingQueueProvider struct {
744+
queueProvider
745+
calls [][]chat.Message // messages sent on each call
746+
}
747+
748+
func (p *capturingQueueProvider) CreateChatCompletionStream(_ context.Context, msgs []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
749+
p.mu.Lock()
750+
defer p.mu.Unlock()
751+
p.calls = append(p.calls, msgs)
752+
if len(p.streams) == 0 {
753+
return &mockStream{}, nil
754+
}
755+
s := p.streams[0]
756+
p.streams = p.streams[1:]
757+
return s, nil
758+
}
759+
760+
func TestCompaction_ContinuationMessageSent(t *testing.T) {
761+
// After auto-compaction, the runtime must inject a continuation user
762+
// message so that the model receives at least one non-system message.
763+
// This prevents providers (e.g. Anthropic) from rejecting the request
764+
// with "messages: Field required".
765+
766+
// Stream 1: initial response that pushes usage above 90% of context.
767+
mainStream := newStreamBuilder().
768+
AddContent("Hello there").
769+
AddStopWithUsage(101, 0). // will exceed 90% of 100
770+
Build()
771+
772+
// Stream 2: summary generation (used by the compactor runtime).
773+
summaryStream := newStreamBuilder().
774+
AddContent("summary of conversation").
775+
AddStopWithUsage(1, 1).
776+
Build()
777+
778+
// Stream 3: the model call that happens after compaction.
779+
postCompactionStream := newStreamBuilder().
780+
AddContent("I'll continue.").
781+
AddStopWithUsage(5, 3).
782+
Build()
783+
784+
prov := &capturingQueueProvider{
785+
queueProvider: queueProvider{
786+
id: "test/mock-model",
787+
streams: []chat.MessageStream{mainStream, summaryStream, postCompactionStream},
788+
},
789+
}
790+
791+
root := agent.New("root", "You are a test agent", agent.WithModel(prov))
792+
tm := team.New(team.WithAgents(root))
793+
794+
rt, err := NewLocalRuntime(tm, WithSessionCompaction(true), WithModelStore(mockModelStoreWithLimit{limit: 100}))
795+
require.NoError(t, err)
796+
797+
// First RunStream: establishes baseline usage.
798+
sess := session.New(session.WithUserMessage("Start"))
799+
for range rt.RunStream(t.Context(), sess) {
800+
}
801+
802+
// Second RunStream: will trigger compaction because usage > 90%.
803+
sess.AddMessage(session.UserMessage("Again"))
804+
for range rt.RunStream(t.Context(), sess) {
805+
}
806+
807+
// The third model call (stream 3 = index 2 from the provider's perspective,
808+
// but calls[1] because first RunStream used calls[0]) should contain
809+
// a user message with the continuation prompt.
810+
require.GreaterOrEqual(t, len(prov.calls), 2, "expected at least 2 model calls in second RunStream")
811+
812+
// The last captured call is the post-compaction model invocation.
813+
lastCallMsgs := prov.calls[len(prov.calls)-1]
814+
815+
// Find the continuation user message.
816+
var foundContinuation bool
817+
for _, msg := range lastCallMsgs {
818+
if msg.Role == chat.MessageRoleUser && msg.Content == "The conversation was automatically compacted. Please continue where you left off." {
819+
foundContinuation = true
820+
break
821+
}
822+
}
823+
require.True(t, foundContinuation, "expected continuation user message after compaction; messages: %v", lastCallMsgs)
824+
825+
// Also check the summary is in the system messages.
826+
var foundSummary bool
827+
for _, msg := range lastCallMsgs {
828+
if msg.Role == chat.MessageRoleSystem && strings.Contains(msg.Content, "summary of conversation") {
829+
foundSummary = true
830+
break
831+
}
832+
}
833+
require.True(t, foundSummary, "expected session summary in system messages after compaction")
834+
}
835+
836+
func TestCompaction_TokenCountersResetAfterSuccess(t *testing.T) {
837+
// After successful compaction the token counters must be reset to zero
838+
// to prevent the compaction check from triggering again immediately.
839+
840+
mainStream := newStreamBuilder().
841+
AddContent("Hello").
842+
AddStopWithUsage(101, 0).
843+
Build()
844+
845+
summaryStream := newStreamBuilder().
846+
AddContent("summary").
847+
AddStopWithUsage(1, 1).
848+
Build()
849+
850+
// Post-compaction stream with LOW usage.
851+
postCompactionStream := newStreamBuilder().
852+
AddContent("Continuing").
853+
AddStopWithUsage(10, 5).
854+
Build()
855+
856+
prov := &queueProvider{
857+
id: "test/mock-model",
858+
streams: []chat.MessageStream{mainStream, summaryStream, postCompactionStream},
859+
}
860+
861+
root := agent.New("root", "You are a test agent", agent.WithModel(prov))
862+
tm := team.New(team.WithAgents(root))
863+
864+
rt, err := NewLocalRuntime(tm, WithSessionCompaction(true), WithModelStore(mockModelStoreWithLimit{limit: 100}))
865+
require.NoError(t, err)
866+
867+
sess := session.New(session.WithUserMessage("Start"))
868+
for range rt.RunStream(t.Context(), sess) {
869+
}
870+
871+
sess.AddMessage(session.UserMessage("Again"))
872+
873+
var events []Event
874+
for ev := range rt.RunStream(t.Context(), sess) {
875+
events = append(events, ev)
876+
}
877+
878+
// Count how many times compaction started — should be exactly once.
879+
var compactionCount int
880+
for _, ev := range events {
881+
if e, ok := ev.(*SessionCompactionEvent); ok && e.Status == "started" {
882+
compactionCount++
883+
}
884+
}
885+
require.Equal(t, 1, compactionCount, "compaction should trigger exactly once, not loop")
886+
887+
// Token counters should reflect the post-compaction model call, not the old values.
888+
assert.Equal(t, int64(10), sess.InputTokens, "InputTokens should be from post-compaction call")
889+
assert.Equal(t, int64(5), sess.OutputTokens, "OutputTokens should be from post-compaction call")
890+
}
891+
892+
func TestCompaction_FailedCompactionNoStrayMessage(t *testing.T) {
893+
// When compaction fails (summary runtime errors), no continuation
894+
// message should be added and the original messages should be preserved.
895+
896+
// Stream 1: initial response with high usage.
897+
mainStream := newStreamBuilder().
898+
AddContent("Hello").
899+
AddStopWithUsage(101, 0).
900+
Build()
901+
902+
// Stream 2: summary generation FAILS — the stream returns an error.
903+
failingStream := &errorStream{err: fmt.Errorf("simulated API error during summary")}
904+
905+
prov := &queueProvider{
906+
id: "test/mock-model",
907+
streams: []chat.MessageStream{mainStream, failingStream},
908+
}
909+
910+
root := agent.New("root", "You are a test agent", agent.WithModel(prov))
911+
tm := team.New(team.WithAgents(root))
912+
913+
rt, err := NewLocalRuntime(tm, WithSessionCompaction(true), WithModelStore(mockModelStoreWithLimit{limit: 100}))
914+
require.NoError(t, err)
915+
916+
sess := session.New(session.WithUserMessage("Start"))
917+
for range rt.RunStream(t.Context(), sess) {
918+
}
919+
920+
sess.AddMessage(session.UserMessage("Again"))
921+
922+
for range rt.RunStream(t.Context(), sess) {
923+
}
924+
925+
// No summary was appended, so no continuation message should exist.
926+
for _, item := range sess.Messages {
927+
if item.IsMessage() && item.Message.Message.Content == "The conversation was automatically compacted. Please continue where you left off." {
928+
t.Fatal("found stray continuation message after failed compaction")
929+
}
930+
}
931+
932+
// No summary items should exist.
933+
for _, item := range sess.Messages {
934+
if item.Summary != "" {
935+
t.Fatal("found summary item after failed compaction")
936+
}
937+
}
938+
}
939+
728940
func TestSessionWithoutUserMessage(t *testing.T) {
729941
stream := newStreamBuilder().AddContent("OK").AddStopWithUsage(1, 1).Build()
730942

pkg/runtime/session_compaction.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ func (c *sessionCompactor) Compact(ctx context.Context, sess *session.Session, a
8888
}
8989

9090
sess.Messages = append(sess.Messages, session.Item{Summary: summary})
91+
92+
// After compaction, the summary is the last item. GetMessages starts
93+
// collecting conversation messages after the last summary, so there
94+
// would be zero conversation messages. Providers (e.g. Anthropic)
95+
// reject requests with no non-system messages, so we add a
96+
// continuation message to bridge the gap.
97+
sess.AddMessage(session.ImplicitUserMessage("The conversation was automatically compacted. Please continue where you left off."))
98+
9199
_ = c.sessionStore.UpdateSession(ctx, sess)
92100

93101
slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary))

0 commit comments

Comments
 (0)