Skip to content

Commit d25ced3

Browse files
authored
feat(adk): PreserveSkillsConfig support MaxTokensPerSkill and SkillsTokenBudget (#967)
1 parent 7677537 commit d25ced3

4 files changed

Lines changed: 266 additions & 10 deletions

File tree

adk/middlewares/summarization/finalizer_builder.go

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222
"strings"
23+
"unicode/utf8"
2324

2425
"github.com/bytedance/sonic"
2526

@@ -97,16 +98,30 @@ func (b *FinalizerBuilder) Build() (FinalizeFunc, error) {
9798
}
9899

99100
// PreserveSkillsConfig controls which loaded skill contents are kept, and how
// much of each, when the preserved-skills text is built.
type PreserveSkillsConfig struct {
101+
// SkillToolName is the tool name used for loading skills.
102+
// Must match the tool name configured in the ADK skill middleware.
103+
// Optional. Defaults to "skill".
104+
SkillToolName string
105+
100106
// MaxSkills limits the maximum number of skills to preserve.
101107
// = 0 means do not preserve any skills (disabled).
102108
// > 0 means preserve up to this many most recent skills.
// Negative values are rejected by check().
103109
// Optional. Defaults to 5.
104110
MaxSkills *int
105111

106-
// SkillToolName is the tool name used for loading skills.
107-
// Must match the tool name configured in the ADK skill middleware.
108-
// Optional. Defaults to "skill".
109-
SkillToolName string
112+
// MaxTokensPerSkill limits the maximum token count for a single preserved skill.
// Token counts are estimated from the byte length of the content
// (about 4 bytes per token), not computed by a real tokenizer.
113+
// Skills exceeding this limit are truncated, with the truncated portion replaced
114+
// by a short marker text (e.g. "[... skill content truncated ...]").
115+
// Note: if this value is set smaller than the token count of the marker text itself,
116+
// the skill will contain only the marker text with no original content preserved.
// Negative values are rejected by check().
117+
// Optional. Defaults to 5000.
118+
MaxTokensPerSkill *int
119+
120+
// SkillsTokenBudget limits the total token count for all preserved skills combined.
121+
// Skills are preserved from most recent to oldest; once the budget is exhausted,
122+
// remaining skills are dropped.
// The first skill that does not fit stops the selection, so older skills are
// dropped even if they would be small enough on their own.
// A truncated skill is counted as MaxTokensPerSkill tokens against the budget.
// Negative values are rejected by check().
123+
// Optional. Defaults to 25000.
124+
SkillsTokenBudget *int
110125
}
111126

112127
// PreserveSkills extracts skill contents loaded by the ADK skill middleware
@@ -149,6 +164,12 @@ func (c *PreserveSkillsConfig) check() error {
149164
if c.MaxSkills != nil && *c.MaxSkills < 0 {
150165
return fmt.Errorf("MaxSkills must be non-negative")
151166
}
167+
if c.MaxTokensPerSkill != nil && *c.MaxTokensPerSkill < 0 {
168+
return fmt.Errorf("MaxTokensPerSkill must be non-negative")
169+
}
170+
if c.SkillsTokenBudget != nil && *c.SkillsTokenBudget < 0 {
171+
return fmt.Errorf("SkillsTokenBudget must be non-negative")
172+
}
152173
return nil
153174
}
154175

@@ -158,7 +179,11 @@ type skillInfo struct {
158179
}
159180

160181
func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config *PreserveSkillsConfig) (string, error) {
161-
const defaultSkillTool = "skill"
182+
const (
183+
defaultSkillTool = "skill"
184+
defaultMaxTokensPerSkill = 5000
185+
defaultSkillsTokenBudget = 25000
186+
)
162187

163188
if config == nil {
164189
config = &PreserveSkillsConfig{}
@@ -177,6 +202,16 @@ func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config
177202
skillTool = config.SkillToolName
178203
}
179204

205+
maxTokensPerSkill := defaultMaxTokensPerSkill
206+
if config.MaxTokensPerSkill != nil {
207+
maxTokensPerSkill = *config.MaxTokensPerSkill
208+
}
209+
210+
skillsTokenBudget := defaultSkillsTokenBudget
211+
if config.SkillsTokenBudget != nil {
212+
skillsTokenBudget = *config.SkillsTokenBudget
213+
}
214+
180215
var skills []*skillInfo
181216
argsMap := make(map[string]string)
182217

@@ -188,6 +223,7 @@ func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config
188223
argsMap[tc.ID] = tc.Function.Arguments
189224
}
190225
}
226+
191227
case schema.Tool:
192228
arguments, ok := argsMap[msg.ToolCallID]
193229
if !ok {
@@ -231,8 +267,39 @@ func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config
231267
skills = skills[len(skills)-maxSkills:]
232268
}
233269

270+
totalTokens := 0
271+
var budgetedSkills []*skillInfo
272+
for i := len(skills) - 1; i >= 0; i-- {
273+
skill := skills[i]
274+
tokens := estimateTokenCount(skill.Content)
275+
276+
if tokens > maxTokensPerSkill {
277+
skill = &skillInfo{
278+
Name: skill.Name,
279+
Content: truncateSkillContent(skill.Content, maxTokensPerSkill),
280+
}
281+
tokens = maxTokensPerSkill
282+
}
283+
284+
if totalTokens+tokens > skillsTokenBudget {
285+
break
286+
}
287+
288+
totalTokens += tokens
289+
budgetedSkills = append(budgetedSkills, skill)
290+
}
291+
292+
if len(budgetedSkills) == 0 {
293+
return "", nil
294+
}
295+
296+
// Reverse to restore chronological order.
297+
for i, j := 0, len(budgetedSkills)-1; i < j; i, j = i+1, j-1 {
298+
budgetedSkills[i], budgetedSkills[j] = budgetedSkills[j], budgetedSkills[i]
299+
}
300+
234301
var parts []string
235-
for _, skill := range skills {
302+
for _, skill := range budgetedSkills {
236303
parts = append(parts, fmt.Sprintf(skillSectionFormat, skill.Name, skill.Content))
237304
}
238305

@@ -242,3 +309,33 @@ func buildPreservedSkillsText(_ context.Context, messages []adk.Message, config
242309

243310
return skillsText, nil
244311
}
312+
313+
// truncateSkillContent truncates skill content to fit within maxTokens.
314+
// It keeps the first portion of the content and appends a truncation marker
315+
// (e.g. "[... skill content truncated ...]") to indicate the omission.
316+
// If maxTokens is smaller than the marker itself, only the marker is returned.
317+
func truncateSkillContent(content string, maxTokens int) string {
318+
if len(content) == 0 {
319+
return content
320+
}
321+
322+
if estimateTokenCount(content) <= maxTokens {
323+
return content
324+
}
325+
326+
marker := getSkillTruncationMarker()
327+
targetBytes := estimateTokenBytes(maxTokens) - len(marker)
328+
if targetBytes < 0 {
329+
targetBytes = 0
330+
}
331+
if targetBytes > len(content) {
332+
targetBytes = len(content)
333+
}
334+
335+
// Back up to a valid UTF-8 rune boundary.
336+
for targetBytes > 0 && targetBytes < len(content) && !utf8.RuneStart(content[targetBytes]) {
337+
targetBytes--
338+
}
339+
340+
return content[:targetBytes] + marker
341+
}

adk/middlewares/summarization/finalizer_builder_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ package summarization
1919
import (
2020
"context"
2121
"errors"
22+
"strings"
2223
"testing"
24+
"unicode/utf8"
2325

2426
"github.com/stretchr/testify/assert"
2527

@@ -212,6 +214,24 @@ func TestPreserveSkillsConfigCheck(t *testing.T) {
212214
err := c.check()
213215
assert.NoError(t, err)
214216
})
217+
218+
t.Run("negative max tokens per skill", func(t *testing.T) {
219+
c := &PreserveSkillsConfig{
220+
MaxTokensPerSkill: ptr(-1),
221+
}
222+
err := c.check()
223+
assert.Error(t, err)
224+
assert.Contains(t, err.Error(), "MaxTokensPerSkill must be non-negative")
225+
})
226+
227+
t.Run("negative skills token budget", func(t *testing.T) {
228+
c := &PreserveSkillsConfig{
229+
SkillsTokenBudget: ptr(-1),
230+
}
231+
err := c.check()
232+
assert.Error(t, err)
233+
assert.Contains(t, err.Error(), "SkillsTokenBudget must be non-negative")
234+
})
215235
}
216236

217237
func TestPreserveSkillsViaBuilder(t *testing.T) {
@@ -434,4 +454,128 @@ func TestBuildPreservedSkillsText(t *testing.T) {
434454
assert.NotContains(t, text, "skill2")
435455
assert.NotContains(t, text, "c2")
436456
})
457+
458+
t.Run("per skill token limit truncates large skills", func(t *testing.T) {
459+
// estimateTokenCount = (len+3)/4
460+
// "short" = 5 chars → 2 tokens
461+
// strings.Repeat("x", 100) = 100 chars → 25 tokens
462+
largeContent := strings.Repeat("x", 100)
463+
messages := []adk.Message{
464+
{
465+
Role: schema.Assistant,
466+
ToolCalls: []schema.ToolCall{
467+
{ID: "call_1", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "small"}`}},
468+
{ID: "call_2", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "large"}`}},
469+
},
470+
},
471+
{Role: schema.Tool, ToolCallID: "call_1", Content: "short"},
472+
{Role: schema.Tool, ToolCallID: "call_2", Content: largeContent},
473+
}
474+
475+
// MaxTokensPerSkill=10: "short"→2 tokens (ok), largeContent→25 tokens (truncated)
476+
text, err := buildPreservedSkillsText(ctx, messages, &PreserveSkillsConfig{
477+
MaxSkills: ptr(10),
478+
MaxTokensPerSkill: ptr(10),
479+
SkillToolName: "load_skill",
480+
})
481+
assert.NoError(t, err)
482+
// small skill preserved as-is
483+
assert.Contains(t, text, "small")
484+
assert.Contains(t, text, "short")
485+
// large skill is truncated, not dropped — name still present, full content gone
486+
assert.Contains(t, text, "large")
487+
assert.NotContains(t, text, largeContent)
488+
assert.Contains(t, text, "skill content truncated for compaction")
489+
})
490+
491+
t.Run("total token budget drops excess skills", func(t *testing.T) {
492+
// Each content is 40 chars → (40+3)/4 = 10 tokens
493+
content := strings.Repeat("a", 40)
494+
messages := []adk.Message{
495+
{
496+
Role: schema.Assistant,
497+
ToolCalls: []schema.ToolCall{
498+
{ID: "call_1", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "skill1"}`}},
499+
{ID: "call_2", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "skill2"}`}},
500+
{ID: "call_3", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "skill3"}`}},
501+
},
502+
},
503+
{Role: schema.Tool, ToolCallID: "call_1", Content: content},
504+
{Role: schema.Tool, ToolCallID: "call_2", Content: content},
505+
{Role: schema.Tool, ToolCallID: "call_3", Content: content},
506+
}
507+
508+
// Budget=15: skill3=10 tokens fits, skill2=10 tokens → 10+10=20 > 15, stop.
509+
text, err := buildPreservedSkillsText(ctx, messages, &PreserveSkillsConfig{
510+
MaxSkills: ptr(10),
511+
SkillsTokenBudget: ptr(15),
512+
SkillToolName: "load_skill",
513+
})
514+
assert.NoError(t, err)
515+
assert.Contains(t, text, "skill3")
516+
assert.NotContains(t, text, "skill1")
517+
assert.NotContains(t, text, "skill2")
518+
})
519+
520+
t.Run("token budget and per-skill limit combined", func(t *testing.T) {
521+
// s1: 16 chars → 4 tokens
522+
// s2: 200 chars → 50 tokens (exceeds per-skill limit of 20, gets truncated)
523+
// s3: 24 chars → 6 tokens
524+
// s4: 24 chars → 6 tokens
525+
messages := []adk.Message{
526+
{
527+
Role: schema.Assistant,
528+
ToolCalls: []schema.ToolCall{
529+
{ID: "call_1", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "s1"}`}},
530+
{ID: "call_2", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "s2"}`}},
531+
{ID: "call_3", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "s3"}`}},
532+
{ID: "call_4", Function: schema.FunctionCall{Name: "load_skill", Arguments: `{"skill": "s4"}`}},
533+
},
534+
},
535+
{Role: schema.Tool, ToolCallID: "call_1", Content: strings.Repeat("a", 16)},
536+
{Role: schema.Tool, ToolCallID: "call_2", Content: strings.Repeat("b", 200)},
537+
{Role: schema.Tool, ToolCallID: "call_3", Content: strings.Repeat("c", 24)},
538+
{Role: schema.Tool, ToolCallID: "call_4", Content: strings.Repeat("d", 24)},
539+
}
540+
541+
// Per-skill limit: 20 (s2 with 50 tokens is truncated to 20)
542+
// Budget: 30 (from most recent: s4=6, s3=6, s2=20, total=32 > 30, so s2 cannot fit)
543+
// Result: s4 and s3 preserved
544+
text, err := buildPreservedSkillsText(ctx, messages, &PreserveSkillsConfig{
545+
MaxSkills: ptr(10),
546+
MaxTokensPerSkill: ptr(20),
547+
SkillsTokenBudget: ptr(30),
548+
SkillToolName: "load_skill",
549+
})
550+
assert.NoError(t, err)
551+
assert.Contains(t, text, "s3")
552+
assert.Contains(t, text, "s4")
553+
assert.NotContains(t, text, "\"s1\"")
554+
assert.NotContains(t, text, "\"s2\"")
555+
})
556+
557+
t.Run("truncated skill content preserves only prefix", func(t *testing.T) {
558+
// Use a long content and generous maxTokens so the prefix is clearly visible.
559+
content := strings.Repeat("abcdefghij", 100) // 1000 bytes → 250 tokens
560+
// maxTokens=125 → targetBytes = 500, minus ~101 marker bytes → ~399 prefix bytes
561+
truncated := truncateSkillContent(content, 125)
562+
assert.True(t, strings.HasPrefix(truncated, "abcdefghij")) // prefix preserved
563+
assert.Contains(t, truncated, "skill content truncated for compaction")
564+
assert.NotEqual(t, content, truncated)
565+
// Ends with marker, not with original content suffix
566+
assert.True(t, strings.HasSuffix(truncated, "]"))
567+
// No suffix from original content
568+
assert.False(t, strings.HasSuffix(truncated, "abcdefghij]"))
569+
})
570+
571+
t.Run("truncated multibyte content does not produce invalid utf8", func(t *testing.T) {
572+
// Each Chinese char is 3 bytes. 334 chars = 1002 bytes → 251 tokens
573+
content := strings.Repeat("中", 334)
574+
// maxTokens=125 → targetBytes=500, minus marker ~101 bytes → ~399 bytes
575+
// 399 / 3 = 133 full Chinese chars, no partial rune
576+
truncated := truncateSkillContent(content, 125)
577+
assert.True(t, utf8.ValidString(truncated))
578+
assert.True(t, strings.HasPrefix(truncated, "中中中"))
579+
assert.Contains(t, truncated, "skill content truncated for compaction")
580+
})
437581
}

adk/middlewares/summarization/prompt.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,14 @@ func getSkillPreamble() string {
327327
Chinese: skillPreambleZh,
328328
})
329329
}
330+
331+
// skillTruncationMarker is the English marker text; it is one of the two
// variants returned by getSkillTruncationMarker.
const skillTruncationMarker = "\n\n[... skill content truncated for compaction; use Read on the skill path if you need the full text]"
332+
333+
// skillTruncationMarkerZh is the Chinese variant of skillTruncationMarker.
const skillTruncationMarkerZh = "\n\n[... skill 内容已在压缩时截断,如需完整内容请通过 Read 读取 skill 对应的文件路径]"
334+
335+
func getSkillTruncationMarker() string {
336+
return internal.SelectPrompt(internal.I18nPrompts{
337+
English: skillTruncationMarker,
338+
Chinese: skillTruncationMarkerZh,
339+
})
340+
}

adk/middlewares/summarization/summarization.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ func SummarizeMessages(ctx context.Context, cfg *Config, messages []adk.Message)
296296
return nil, err
297297
}
298298

299-
ctx = context.WithValue(ctx, ctxKeyModelInput{}, modelInput)
299+
finalizeCtx := context.WithValue(ctx, ctxKeyModelInput{}, modelInput)
300300

301-
_, finalMsgs, err := m.finalizeSummary(ctx, messages, rawSummary)
301+
_, finalMsgs, err := m.finalizeSummary(finalizeCtx, messages, rawSummary)
302302
if err != nil {
303303
return nil, err
304304
}
@@ -355,10 +355,10 @@ func (m *middleware) BeforeModelRewriteState(ctx context.Context, state *adk.Cha
355355
return nil, nil, err
356356
}
357357

358-
ctx = context.WithValue(ctx, ctxKeyModelInput{}, modelInput)
358+
finalizeCtx := context.WithValue(ctx, ctxKeyModelInput{}, modelInput)
359359

360360
var finalMsgs []adk.Message
361-
ctx, finalMsgs, err = m.finalizeSummary(ctx, beforeState.Messages, rawSummary)
361+
_, finalMsgs, err = m.finalizeSummary(finalizeCtx, beforeState.Messages, rawSummary)
362362
if err != nil {
363363
return nil, nil, err
364364
}
@@ -503,6 +503,10 @@ func estimateTokenCount(text string) int {
503503
return (len(text) + 3) / 4
504504
}
505505

506+
// estimateTokenBytes converts a token count into the number of bytes it is
// assumed to cover, using the same 4-bytes-per-token ratio that
// estimateTokenCount applies in the opposite direction.
func estimateTokenBytes(tokens int) int {
	const bytesPerToken = 4
	return bytesPerToken * tokens
}
509+
506510
func (m *middleware) summarize(ctx context.Context, originalMsgs []adk.Message) (adk.Message, []adk.Message, error) {
507511
_, contextMsgs := m.splitSystemAndContextMsgs(originalMsgs)
508512

0 commit comments

Comments
 (0)