Skip to content

Commit 60cb3b1

Browse files
committed
feat(adk): define PopulateUserMessages utility for summariztion middleware
1 parent 8edb235 commit 60cb3b1

4 files changed

Lines changed: 458 additions & 145 deletions

File tree

adk/middlewares/summarization/finalizer_builder.go

Lines changed: 193 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -28,42 +28,57 @@ import (
2828
"github.com/cloudwego/eino/schema"
2929
)
3030

31-
type FinalizerBuilder struct {
32-
handlers []FinalizeFunc
33-
custom FinalizeFunc
34-
errs []error
35-
}
36-
37-
// NewFinalizer creates a new FinalizerBuilder that builds a FinalizeFunc
38-
// by chaining handlers and an optional custom finalizer.
31+
// TypedFinalizerBuilder builds a TypedFinalizeFunc by chaining handlers
32+
// and an optional custom finalizer, generic over message type M.
33+
//
3934
// Handlers (e.g. PreserveSkills) transform the summary message sequentially,
4035
// and the custom finalizer (set via Custom) determines the final output messages.
4136
//
4237
// Example:
4338
//
4439
// finalizer, err := NewFinalizer().
45-
// PreserveSkills(&PreserveSkillsConfig{}).
46-
// Custom(func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
47-
// return []adk.Message{schema.SystemMessage("system prompt"), summary}, nil
48-
// }).
49-
// Build()
40+
// PreserveSkills(&PreserveSkillsConfig{}).
41+
// Custom(func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
42+
// return []adk.Message{schema.SystemMessage("system prompt"), summary}, nil
43+
// }).
44+
// Build()
5045
//
5146
// cfg := &Config{
52-
// Finalize: finalizer,
53-
// // ...
47+
// Finalize: finalizer,
48+
// // ...
5449
// }
50+
type TypedFinalizerBuilder[M adk.MessageType] struct {
51+
handlers []TypedFinalizeFunc[M]
52+
custom TypedFinalizeFunc[M]
53+
errs []error
54+
}
55+
56+
// FinalizerBuilder is a backward-compatible alias for TypedFinalizerBuilder
57+
// specialized with *schema.Message.
58+
type FinalizerBuilder = TypedFinalizerBuilder[*schema.Message]
59+
60+
// NewTypedFinalizer creates a new TypedFinalizerBuilder that builds a TypedFinalizeFunc
61+
// by chaining handlers and an optional custom finalizer.
62+
func NewTypedFinalizer[M adk.MessageType]() *TypedFinalizerBuilder[M] {
63+
return &TypedFinalizerBuilder[M]{}
64+
}
65+
66+
// NewFinalizer creates a new FinalizerBuilder that builds a FinalizeFunc
67+
// by chaining handlers and an optional custom finalizer.
5568
func NewFinalizer() *FinalizerBuilder {
5669
return &FinalizerBuilder{}
5770
}
5871

5972
// Custom sets a custom finalizer that determines the final output messages.
6073
// If called multiple times, the last custom finalizer takes effect.
61-
func (b *FinalizerBuilder) Custom(fn FinalizeFunc) *FinalizerBuilder {
74+
func (b *TypedFinalizerBuilder[M]) Custom(fn TypedFinalizeFunc[M]) *TypedFinalizerBuilder[M] {
6275
b.custom = fn
6376
return b
6477
}
6578

66-
func (b *FinalizerBuilder) Build() (FinalizeFunc, error) {
79+
// Build constructs the final TypedFinalizeFunc by chaining all registered handlers
80+
// and the optional custom finalizer.
81+
func (b *TypedFinalizerBuilder[M]) Build() (TypedFinalizeFunc[M], error) {
6782
if len(b.errs) > 0 {
6883
msgs := make([]string, len(b.errs))
6984
for i, e := range b.errs {
@@ -76,11 +91,11 @@ func (b *FinalizerBuilder) Build() (FinalizeFunc, error) {
7691
return nil, fmt.Errorf("at least one handler or custom finalizer is required")
7792
}
7893

79-
handlers := make([]FinalizeFunc, len(b.handlers))
94+
handlers := make([]TypedFinalizeFunc[M], len(b.handlers))
8095
copy(handlers, b.handlers)
8196
custom := b.custom
8297

83-
return func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
98+
return func(ctx context.Context, originalMessages []M, summary M) ([]M, error) {
8499
for _, fn := range handlers {
85100
result, err := fn(ctx, originalMessages, summary)
86101
if err != nil {
@@ -93,7 +108,7 @@ func (b *FinalizerBuilder) Build() (FinalizeFunc, error) {
93108
return custom(ctx, originalMessages, summary)
94109
}
95110

96-
return []adk.Message{summary}, nil
111+
return []M{summary}, nil
97112
}, nil
98113
}
99114

@@ -127,36 +142,108 @@ type PreserveSkillsConfig struct {
127142
// PreserveSkills extracts skill contents loaded by the ADK skill middleware
128143
// from the conversation history and prepends them to the summary message,
129144
// ensuring the agent retains skill knowledge after the context window is compacted.
130-
func (b *FinalizerBuilder) PreserveSkills(config *PreserveSkillsConfig) *FinalizerBuilder {
145+
func (b *TypedFinalizerBuilder[M]) PreserveSkills(config *PreserveSkillsConfig) *TypedFinalizerBuilder[M] {
131146
if err := config.check(); err != nil {
132147
b.errs = append(b.errs, fmt.Errorf("PreserveSkills: %w", err))
133148
return b
134149
}
135-
b.handlers = append(b.handlers, func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
136-
modelInput, _ := ctx.Value(ctxKeyModelInput{}).([]adk.Message)
137-
if len(modelInput) == 0 {
138-
panic("impossible: model input is empty")
150+
b.handlers = append(b.handlers, func(ctx context.Context, originalMessages []M, summary M) ([]M, error) {
151+
messages := originalMessages
152+
153+
modelInput, ok := ctx.Value(ctxKeyModelInput{}).([]M)
154+
if ok && len(modelInput) > 0 {
155+
messages = modelInput
156+
}
157+
158+
if len(messages) == 0 {
159+
return []M{summary}, nil
139160
}
140161

141-
skillText, err := buildPreservedSkillsText(ctx, modelInput, config)
162+
skillText, err := buildPreservedSkillsText(ctx, messages, config)
142163
if err != nil {
143164
return nil, err
144165
}
145166

146167
if skillText != "" {
147-
summary.UserInputMultiContent = append([]schema.MessageInputPart{
148-
{
149-
Type: schema.ChatMessagePartTypeText,
150-
Text: skillText,
151-
},
152-
}, summary.UserInputMultiContent...)
168+
summary = prependMsgTextContent(summary, skillText)
153169
}
154170

155-
return []adk.Message{summary}, nil
171+
return []M{summary}, nil
156172
})
157173
return b
158174
}
159175

176+
// PopulateUserMessagesConfig configures the PopulateUserMessages function.
177+
type PopulateUserMessagesConfig[M adk.MessageType] struct {
178+
// MaxTokens limits the maximum token count for preserved user messages.
179+
// Optional. Defaults to 30000.
180+
MaxTokens int
181+
182+
// Filter determines whether a specific user message should be preserved.
183+
// It is called for each user message. If it returns false, the message will not be preserved.
184+
// Optional.
185+
Filter TypedUserMessageFilterFunc[M]
186+
187+
// TokenCounter provides custom token counting.
188+
// Optional. Uses default estimator if not set.
189+
TokenCounter TypedTokenCounterFunc[M]
190+
}
191+
192+
// PopulateUserMessages is a convenience function that replaces the
193+
// <all_user_messages>...</all_user_messages> section in summaryText with
194+
// recent user messages from the given messages.
195+
func PopulateUserMessages[M adk.MessageType](ctx context.Context, messages []M, summaryText string, config PopulateUserMessagesConfig[M]) (string, error) {
196+
if config.MaxTokens < 0 {
197+
return "", fmt.Errorf("MaxTokens must be non-negative")
198+
}
199+
200+
const defaultMaxTokens = 30000
201+
202+
maxTokens := defaultMaxTokens
203+
if config.MaxTokens > 0 {
204+
maxTokens = config.MaxTokens
205+
}
206+
207+
userMsgs := msgsFromFirstUser(messages)
208+
209+
return populateUserMessages(ctx, &populateUserMessagesParams[M]{
210+
messages: userMsgs,
211+
summaryText: summaryText,
212+
maxTokens: maxTokens,
213+
filter: config.Filter,
214+
tokenCounter: config.TokenCounter,
215+
})
216+
}
217+
218+
func msgsFromFirstUser[M adk.MessageType](msgs []M) []M {
219+
for i, msg := range msgs {
220+
if isUserRole(msg) {
221+
return msgs[i:]
222+
}
223+
}
224+
return nil
225+
}
226+
227+
func prependMsgTextContent[M adk.MessageType](msg M, text string) M {
228+
switch m := any(msg).(type) {
229+
case *schema.Message:
230+
m.UserInputMultiContent = append([]schema.MessageInputPart{
231+
{
232+
Type: schema.ChatMessagePartTypeText,
233+
Text: text,
234+
},
235+
}, m.UserInputMultiContent...)
236+
return any(m).(M)
237+
case *schema.AgenticMessage:
238+
m.ContentBlocks = append([]*schema.ContentBlock{
239+
schema.NewContentBlock(&schema.UserInputText{Text: text}),
240+
}, m.ContentBlocks...)
241+
return any(m).(M)
242+
default:
243+
panic("unreachable")
244+
}
245+
}
246+
160247
func (c *PreserveSkillsConfig) check() error {
161248
if c == nil {
162249
return fmt.Errorf("PreserveSkillsConfig is required")
@@ -178,7 +265,76 @@ type skillInfo struct {
178265
Content string
179266
}
180267

181-
func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config *PreserveSkillsConfig) (string, error) {
268+
func extractSkillInfos[M adk.MessageType](messages []M, skillTool string) ([]*skillInfo, error) {
269+
var skills []*skillInfo
270+
argsMap := make(map[string]string)
271+
272+
for _, msg := range messages {
273+
switch m := any(msg).(type) {
274+
case *schema.Message:
275+
if m.Role == schema.Assistant {
276+
for _, tc := range m.ToolCalls {
277+
if tc.Function.Name == skillTool {
278+
argsMap[tc.ID] = tc.Function.Arguments
279+
}
280+
}
281+
} else if m.Role == schema.Tool {
282+
arguments, ok := argsMap[m.ToolCallID]
283+
if !ok {
284+
continue
285+
}
286+
var arg struct {
287+
Skill string `json:"skill"`
288+
}
289+
if err := sonic.UnmarshalString(arguments, &arg); err != nil {
290+
return nil, fmt.Errorf("failed to parse skill arguments from tool call %s: %w", m.ToolCallID, err)
291+
}
292+
skills = append(skills, &skillInfo{
293+
Name: arg.Skill,
294+
Content: m.Content,
295+
})
296+
}
297+
298+
case *schema.AgenticMessage:
299+
for _, block := range m.ContentBlocks {
300+
if block == nil {
301+
continue
302+
}
303+
if block.Type == schema.ContentBlockTypeFunctionToolCall && block.FunctionToolCall != nil {
304+
if block.FunctionToolCall.Name == skillTool {
305+
argsMap[block.FunctionToolCall.CallID] = block.FunctionToolCall.Arguments
306+
}
307+
}
308+
if block.Type == schema.ContentBlockTypeFunctionToolResult && block.FunctionToolResult != nil {
309+
arguments, ok := argsMap[block.FunctionToolResult.CallID]
310+
if !ok {
311+
continue
312+
}
313+
var arg struct {
314+
Skill string `json:"skill"`
315+
}
316+
if err := sonic.UnmarshalString(arguments, &arg); err != nil {
317+
return nil, fmt.Errorf("failed to parse skill arguments from tool call %s: %w", block.FunctionToolResult.CallID, err)
318+
}
319+
var contentParts []string
320+
for _, cb := range block.FunctionToolResult.Content {
321+
if cb != nil && cb.Type == schema.FunctionToolResultContentBlockTypeText && cb.Text != nil {
322+
contentParts = append(contentParts, cb.Text.Text)
323+
}
324+
}
325+
skills = append(skills, &skillInfo{
326+
Name: arg.Skill,
327+
Content: strings.Join(contentParts, "\n"),
328+
})
329+
}
330+
}
331+
}
332+
}
333+
334+
return skills, nil
335+
}
336+
337+
func buildPreservedSkillsText[M adk.MessageType](_ context.Context, messages []M, config *PreserveSkillsConfig) (string, error) {
182338
const (
183339
defaultSkillTool = "skill"
184340
defaultMaxTokensPerSkill = 5000
@@ -212,36 +368,9 @@ func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config
212368
skillsTokenBudget = *config.SkillsTokenBudget
213369
}
214370

215-
var skills []*skillInfo
216-
argsMap := make(map[string]string)
217-
218-
for _, msg := range messages {
219-
switch msg.Role {
220-
case schema.Assistant:
221-
for _, tc := range msg.ToolCalls {
222-
if tc.Function.Name == skillTool {
223-
argsMap[tc.ID] = tc.Function.Arguments
224-
}
225-
}
226-
227-
case schema.Tool:
228-
arguments, ok := argsMap[msg.ToolCallID]
229-
if !ok {
230-
continue
231-
}
232-
233-
var arg struct {
234-
Skill string `json:"skill"`
235-
}
236-
if err := sonic.UnmarshalString(arguments, &arg); err != nil {
237-
return "", fmt.Errorf("failed to parse skill arguments from tool call %s: %w", msg.ToolCallID, err)
238-
}
239-
240-
skills = append(skills, &skillInfo{
241-
Name: arg.Skill,
242-
Content: msg.Content,
243-
})
244-
}
371+
skills, err := extractSkillInfos(messages, skillTool)
372+
if err != nil {
373+
return "", err
245374
}
246375

247376
if len(skills) == 0 {

0 commit comments

Comments
 (0)