diff --git a/pkg/summarize/hierarchy.go b/pkg/summarize/hierarchy.go new file mode 100644 index 0000000..b7af435 --- /dev/null +++ b/pkg/summarize/hierarchy.go @@ -0,0 +1,367 @@ +package summarize + +import ( + "context" + "fmt" + "strings" + "time" + "unicode" +) + +// HierarchicalSummarizer implements Summarizer using rule-based compression. +// It does not require an LLM — compression is performed locally using +// extractive techniques (sentence selection, keyword extraction). +// +// For LLM-backed summarization, wrap this with an LLMSummarizer that +// overrides the compress method. +type HierarchicalSummarizer struct{} + +// NewHierarchicalSummarizer creates a new summarizer. +func NewHierarchicalSummarizer() *HierarchicalSummarizer { + return &HierarchicalSummarizer{} +} + +// Summarize compresses turns to fit within opts.MaxTokens. +// Turns are processed oldest-first; recent turns and high-importance turns +// are preserved at full fidelity. +func (s *HierarchicalSummarizer) Summarize( + ctx context.Context, + turns []Turn, + opts SummarizeOptions, +) ([]Turn, SummarizeStats, error) { + start := time.Now() + + if opts.PreserveRecent < 0 { + opts.PreserveRecent = 10 + } + if opts.ImportanceThreshold <= 0 { + opts.ImportanceThreshold = 0.7 + } + if len(opts.AgeLevels) == 0 { + opts.AgeLevels = DefaultOptions().AgeLevels + } + + // Score importance for turns that don't have it set. + ScoreTurns(turns) + + // Count input tokens. + inputTokens := 0 + for i := range turns { + turns[i].TokenCount = estimateTokens(turns[i].Content) + inputTokens += turns[i].TokenCount + } + + stats := SummarizeStats{ + InputTurns: len(turns), + InputTokens: inputTokens, + } + + // Determine which turns to compress. + now := time.Now() + result := make([]Turn, len(turns)) + copy(result, turns) + + recentCutoff := len(result) - opts.PreserveRecent + if recentCutoff < 0 { + recentCutoff = 0 + } + + for i := range result { + t := &result[i] + + // Always preserve recent turns (only when PreserveRecent > 0). + if opts.PreserveRecent > 0 && i >= recentCutoff { + stats.PreservedTurns++ + continue + } + + // Preserve high-importance turns at LevelFull or LevelParagraph. + maxLevel := s.maxLevelForAge(now.Sub(t.Timestamp), opts.AgeLevels) + if t.Importance >= opts.ImportanceThreshold && maxLevel > LevelParagraph { + maxLevel = LevelParagraph + } + + if maxLevel <= t.Level { + // Already at or beyond target level. + stats.PreservedTurns++ + continue + } + + // Compress to target level. + if err := s.compressTo(t, maxLevel); err != nil { + return nil, stats, fmt.Errorf("compress turn %s: %w", t.ID, err) + } + t.TokenCount = estimateTokens(t.Content) + stats.CompressedTurns++ + } + + // If MaxTokens is set and we're still over budget, do a second pass + // compressing more aggressively from oldest to newest. + if opts.MaxTokens > 0 { + result = s.enforceTokenBudget(result, opts, recentCutoff) + } + + // Compute output stats. + outputTokens := 0 + for _, t := range result { + outputTokens += t.TokenCount + } + stats.OutputTurns = len(result) + stats.OutputTokens = outputTokens + if stats.InputTokens > 0 { + stats.ReductionPct = float64(stats.InputTokens-stats.OutputTokens) / float64(stats.InputTokens) * 100 + } + stats.Latency = time.Since(start) + + return result, stats, nil +} + +// enforceTokenBudget does a second compression pass when still over budget. +// It progressively compresses oldest turns through all levels, including +// eviction (dropping turns entirely) as a last resort. +func (s *HierarchicalSummarizer) enforceTokenBudget( + turns []Turn, + opts SummarizeOptions, + recentCutoff int, +) []Turn { + total := 0 + for _, t := range turns { + total += t.TokenCount + } + if total <= opts.MaxTokens { + return turns + } + + // Compress oldest non-recent turns progressively through all levels. + for level := LevelParagraph; level <= LevelEvicted && total > opts.MaxTokens; level++ { + for i := range turns { + if opts.PreserveRecent > 0 && i >= recentCutoff { + break + } + t := &turns[i] + if t.Level >= level { + continue + } + if t.Importance >= opts.ImportanceThreshold && level > LevelParagraph { + continue + } + before := t.TokenCount + if level == LevelEvicted { + t.Level = LevelEvicted + t.Content = "" + t.TokenCount = 0 + } else { + _ = s.compressTo(t, level) + t.TokenCount = estimateTokens(t.Content) + } + total -= before - t.TokenCount + if total <= opts.MaxTokens { + break + } + } + } + + // Remove evicted turns from the slice. + out := turns[:0] + for _, t := range turns { + if t.Level != LevelEvicted { + out = append(out, t) + } + } + return out +} + +// maxLevelForAge returns the maximum compression level for a given age. +func (s *HierarchicalSummarizer) maxLevelForAge(age time.Duration, levels []AgeLevel) Level { + max := LevelFull + for _, al := range levels { + if age >= al.After && al.MaxLevel > max { + max = al.MaxLevel + } + } + return max +} + +// compressTo compresses a turn to the target level in-place. +// The original content is preserved in Turn.Original on first compression. +func (s *HierarchicalSummarizer) compressTo(t *Turn, target Level) error { + if t.Original == "" { + t.Original = t.Content + } + + switch target { + case LevelParagraph: + t.Content = extractParagraphSummary(t.Original) + case LevelSentence: + t.Content = extractSentenceSummary(t.Original) + case LevelKeywords: + t.Content = extractKeywordSummary(t.Original) + } + t.Level = target + return nil +} + +// extractParagraphSummary keeps the first paragraph and any code blocks. +func extractParagraphSummary(text string) string { + lines := strings.Split(text, "\n") + var out []string + inCode := false + paragraphDone := false + + for _, line := range lines { + if strings.HasPrefix(line, "```") { + inCode = !inCode + out = append(out, line) + continue + } + if inCode { + out = append(out, line) + continue + } + if !paragraphDone { + out = append(out, line) + if line == "" && len(out) > 1 { + paragraphDone = true + } + } + } + result := strings.TrimSpace(strings.Join(out, "\n")) + if result == "" { + return truncate(text, 300) + } + return result +} + +// extractSentenceSummary returns the first 1–2 sentences. +func extractSentenceSummary(text string) string { + // Strip code blocks first. + text = stripCodeBlocks(text) + sentences := splitSentences(text) + if len(sentences) == 0 { + return truncate(text, 150) + } + if len(sentences) == 1 { + return sentences[0] + } + return sentences[0] + " " + sentences[1] +} + +// extractKeywordSummary extracts the most significant words. +func extractKeywordSummary(text string) string { + text = stripCodeBlocks(text) + words := strings.Fields(text) + var keywords []string + seen := map[string]bool{} + for _, w := range words { + w = strings.Trim(w, `.,;:!?"'()[]{}`) + lower := strings.ToLower(w) + if len(w) < 4 || isStopWord(lower) || seen[lower] { + continue + } + seen[lower] = true + keywords = append(keywords, w) + if len(keywords) >= 12 { + break + } + } + return strings.Join(keywords, ", ") +} + +func stripCodeBlocks(text string) string { + var out strings.Builder + inCode := false + for _, line := range strings.Split(text, "\n") { + if strings.HasPrefix(line, "```") { + inCode = !inCode + continue + } + if !inCode { + out.WriteString(line) + out.WriteByte('\n') + } + } + return out.String() +} + +func splitSentences(text string) []string { + var sentences []string + var cur strings.Builder + for _, r := range text { + cur.WriteRune(r) + if r == '.' || r == '!' || r == '?' { + s := strings.TrimSpace(cur.String()) + if s != "" { + sentences = append(sentences, s) + } + cur.Reset() + } + } + if s := strings.TrimSpace(cur.String()); s != "" { + sentences = append(sentences, s) + } + return sentences +} + +func truncate(s string, maxRunes int) string { + runes := []rune(s) + if len(runes) <= maxRunes { + return s + } + return string(runes[:maxRunes]) + "…" +} + +func isStopWord(w string) bool { + return stopWords[w] +} + +var stopWords = func() map[string]bool { + words := []string{ + "the", "and", "for", "that", "this", "with", "from", "have", + "will", "been", "were", "they", "their", "there", "when", + "what", "which", "would", "could", "should", "about", "into", + "more", "also", "some", "than", "then", "just", "like", + } + m := map[string]bool{} + for _, w := range words { + m[w] = true + } + return m +}() + +// DetectTurns segments a flat message list into Turn structs, assigning +// timestamps based on index when real timestamps are unavailable. +func DetectTurns(messages []struct { + Role string + Content string +}) []Turn { + now := time.Now() + turns := make([]Turn, len(messages)) + for i, m := range messages { + turns[i] = Turn{ + ID: fmt.Sprintf("turn-%d", i), + Role: m.Role, + Content: m.Content, + Original: m.Content, + Timestamp: now.Add(-time.Duration(len(messages)-i) * time.Minute), + Level: LevelFull, + TokenCount: estimateTokens(m.Content), + } + } + return turns +} + +// TotalTokens returns the sum of token counts across all turns. +func TotalTokens(turns []Turn) int { + total := 0 + for _, t := range turns { + total += t.TokenCount + } + return total +} + +// isLetter is used by keyword extraction. +func isLetter(r rune) bool { + return unicode.IsLetter(r) +} + +var _ = isLetter // suppress unused warning diff --git a/pkg/summarize/importance.go b/pkg/summarize/importance.go new file mode 100644 index 0000000..912d559 --- /dev/null +++ b/pkg/summarize/importance.go @@ -0,0 +1,98 @@ +package summarize + +import ( + "strings" + "unicode" +) + +// ScoreImportance returns an importance score (0–1) for a turn based on +// heuristic signals. Higher scores mean the turn should resist compression. +// +// Signals: +// - Contains a code block (```) → +0.4 +// - Contains an error keyword → +0.3 +// - Contains a decision keyword → +0.2 +// - System role → always 1.0 +// - Tool role → +0.2 +// - Short content (< 50 chars) → −0.1 +func ScoreImportance(t Turn) float64 { + if t.Role == "system" { + return 1.0 + } + + score := 0.5 // baseline + lower := strings.ToLower(t.Content) + + // Code blocks are high-value. + if strings.Contains(t.Content, "```") || strings.Contains(t.Content, "\t") { + score += 0.4 + } + + // Error signals. + for _, kw := range errorKeywords { + if strings.Contains(lower, kw) { + score += 0.3 + break + } + } + + // Decision / conclusion signals. + for _, kw := range decisionKeywords { + if strings.Contains(lower, kw) { + score += 0.2 + break + } + } + + // Tool calls carry structured data. + if t.Role == "tool" { + score += 0.2 + } + + // Very short turns are usually low-value. + if len([]rune(t.Content)) < 50 { + score -= 0.1 + } + + // Clamp to [0, 1]. + if score > 1.0 { + score = 1.0 + } + if score < 0 { + score = 0 + } + return score +} + +// ScoreTurns sets Importance on each turn in-place. +func ScoreTurns(turns []Turn) { + for i := range turns { + if turns[i].Importance == 0 { + turns[i].Importance = ScoreImportance(turns[i]) + } + } +} + +// estimateTokens approximates token count using the 4-chars-per-token heuristic. +func estimateTokens(s string) int { + // Count printable runes only. + n := 0 + for _, r := range s { + if !unicode.IsSpace(r) { + n++ + } + } + return (n + 3) / 4 +} + +var errorKeywords = []string{ + "error", "exception", "panic", "fatal", "failed", "failure", + "crash", "bug", "traceback", "stack trace", "nil pointer", + "segfault", "timeout", "deadlock", +} + +var decisionKeywords = []string{ + "decided", "decision", "conclusion", "therefore", "we will", + "we should", "let's use", "going with", "chosen", "agreed", + "final answer", "solution is", "approach is", +} diff --git a/pkg/summarize/summarize.go b/pkg/summarize/summarize.go new file mode 100644 index 0000000..c5e8315 --- /dev/null +++ b/pkg/summarize/summarize.go @@ -0,0 +1,91 @@ +// Package summarize provides hierarchical multi-level summarization for +// long conversation sessions. Turns are compressed progressively as they age: +// +// Level 0: full content (recent) +// Level 1: paragraph summary (medium age) +// Level 2: single-sentence summary (old) +// Level 3: keywords only (very old) +package summarize + +import ( + "context" + "time" +) + +// Level represents a compression level for a conversation turn. +type Level int + +const ( + LevelFull Level = 0 // original content + LevelParagraph Level = 1 // paragraph summary + LevelSentence Level = 2 // single-sentence summary + LevelKeywords Level = 3 // keywords only + LevelEvicted Level = 4 // dropped entirely (zero tokens) +) + +// Turn represents a single conversation turn. +type Turn struct { + ID string + Role string // "user", "assistant", "tool", "system" + Content string + Original string // preserved original for reversibility + Timestamp time.Time + Level Level + Importance float64 // 0–1; high-importance turns resist compression + TokenCount int +} + +// SummarizeOptions configures a summarization pass. +type SummarizeOptions struct { + // MaxTokens is the target total token budget. 0 = no limit. + MaxTokens int + + // PreserveRecent keeps the N most recent turns at LevelFull regardless + // of age or token pressure. Default: 10. + PreserveRecent int + + // ImportanceThreshold: turns with Importance >= this value are never + // compressed beyond LevelParagraph. Default: 0.7. + ImportanceThreshold float64 + + // AgeLevels maps turn age to the maximum compression level allowed. + // Turns older than AgeLevels[i].After are compressed to AgeLevels[i].MaxLevel. + AgeLevels []AgeLevel +} + +// AgeLevel maps a minimum age to a maximum compression level. +type AgeLevel struct { + After time.Duration + MaxLevel Level +} + +// DefaultOptions returns sensible defaults. +func DefaultOptions() SummarizeOptions { + return SummarizeOptions{ + MaxTokens: 0, + PreserveRecent: 10, + ImportanceThreshold: 0.7, + AgeLevels: []AgeLevel{ + {After: 30 * time.Minute, MaxLevel: LevelParagraph}, + {After: 2 * time.Hour, MaxLevel: LevelSentence}, + {After: 24 * time.Hour, MaxLevel: LevelKeywords}, + }, + } +} + +// SummarizeStats reports what happened during a summarization pass. +type SummarizeStats struct { + InputTurns int + OutputTurns int + InputTokens int + OutputTokens int + CompressedTurns int + PreservedTurns int + ReductionPct float64 + Latency time.Duration +} + +// Summarizer compresses conversation turns to fit within a token budget. +type Summarizer interface { + Summarize(ctx context.Context, turns []Turn, opts SummarizeOptions) ([]Turn, SummarizeStats, error) +} diff --git a/pkg/summarize/summarize_test.go b/pkg/summarize/summarize_test.go new file mode 100644 index 0000000..a10b8f8 --- /dev/null +++ b/pkg/summarize/summarize_test.go @@ -0,0 +1,217 @@ +package summarize + +import ( + "context" + "strings" + "testing" + "time" +) + +func makeTurn(id, role, content string, age time.Duration, importance float64) Turn { + return Turn{ + ID: id, + Role: role, + Content: content, + Original: content, + Timestamp: time.Now().Add(-age), + Level: LevelFull, + Importance: importance, + TokenCount: estimateTokens(content), + } +} + +func TestHierarchicalSummarizer_PreservesRecentTurns(t *testing.T) { + s := NewHierarchicalSummarizer() + ctx := context.Background() + + turns := []Turn{ + makeTurn("1", "user", "Old message from yesterday", 25*time.Hour, 0.3), + makeTurn("2", "assistant", "Old reply from yesterday", 25*time.Hour, 0.3), + makeTurn("3", "user", "Recent message", 1*time.Minute, 0.5), + makeTurn("4", "assistant", "Recent reply", 1*time.Minute, 0.5), + } + + opts := DefaultOptions() + opts.PreserveRecent = 2 + + result, stats, err := s.Summarize(ctx, turns, opts) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + + // Recent turns (3, 4) must be at LevelFull. + if result[2].Level != LevelFull { + t.Errorf("turn 3: expected LevelFull, got %d", result[2].Level) + } + if result[3].Level != LevelFull { + t.Errorf("turn 4: expected LevelFull, got %d", result[3].Level) + } + + // Old turns should be compressed. + if result[0].Level == LevelFull { + t.Errorf("turn 1: expected compression, still at LevelFull") + } + + if stats.CompressedTurns == 0 { + t.Error("expected at least one compressed turn") + } +} + +func TestHierarchicalSummarizer_HighImportanceResistsCompression(t *testing.T) { + s := NewHierarchicalSummarizer() + ctx := context.Background() + + turns := []Turn{ + makeTurn("1", "assistant", "We decided to use PostgreSQL for the database. This is the final architecture decision.", 25*time.Hour, 0.9), + makeTurn("2", "user", "Some old low-importance message", 25*time.Hour, 0.2), + } + + opts := DefaultOptions() + opts.PreserveRecent = 0 + opts.ImportanceThreshold = 0.7 + + result, _, err := s.Summarize(ctx, turns, opts) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + + // High-importance turn should not exceed LevelParagraph. + if result[0].Level > LevelParagraph { + t.Errorf("high-importance turn compressed beyond LevelParagraph: level=%d", result[0].Level) + } +} + +func TestHierarchicalSummarizer_TokenBudget(t *testing.T) { + s := NewHierarchicalSummarizer() + ctx := context.Background() + + // Create many old turns that exceed a small budget. + var turns []Turn + for i := 0; i < 20; i++ { + turns = append(turns, makeTurn( + "t", "user", + strings.Repeat("This is a long message with lots of content. ", 10), + 25*time.Hour, 0.3, + )) + } + + opts := DefaultOptions() + opts.MaxTokens = 200 + opts.PreserveRecent = 2 + + result, stats, err := s.Summarize(ctx, turns, opts) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + + if stats.OutputTokens > opts.MaxTokens { + t.Errorf("output tokens %d exceeds budget %d", stats.OutputTokens, opts.MaxTokens) + } + _ = result +} + +func TestHierarchicalSummarizer_SystemRolePreserved(t *testing.T) { + s := NewHierarchicalSummarizer() + ctx := context.Background() + + turns := []Turn{ + makeTurn("sys", "system", "You are a helpful assistant.", 48*time.Hour, 0), + makeTurn("old", "user", "Old message", 48*time.Hour, 0.2), + } + + opts := DefaultOptions() + opts.PreserveRecent = 0 + + result, _, err := s.Summarize(ctx, turns, opts) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + + // System turn has importance=1.0 after scoring, should not exceed LevelParagraph. + if result[0].Level > LevelParagraph { + t.Errorf("system turn compressed beyond LevelParagraph: level=%d", result[0].Level) + } +} + +func TestScoreImportance(t *testing.T) { + tests := []struct { + turn Turn + wantMin float64 + }{ + {Turn{Role: "system", Content: "You are helpful."}, 1.0}, + {Turn{Role: "assistant", Content: "```go\nfunc main() {}\n```"}, 0.8}, + {Turn{Role: "user", Content: "error: nil pointer dereference"}, 0.7}, + {Turn{Role: "user", Content: "ok"}, 0.0}, + } + + for _, tt := range tests { + score := ScoreImportance(tt.turn) + if score < tt.wantMin { + t.Errorf("role=%s content=%q: score %.2f < min %.2f", + tt.turn.Role, tt.turn.Content, score, tt.wantMin) + } + } +} + +func TestExtractParagraphSummary(t *testing.T) { + text := "First paragraph.\n\nSecond paragraph that should be excluded." + result := extractParagraphSummary(text) + if strings.Contains(result, "Second paragraph") { + t.Errorf("paragraph summary should not include second paragraph: %q", result) + } +} + +func TestExtractSentenceSummary(t *testing.T) { + text := "This is the first sentence. This is the second. This is the third." + result := extractSentenceSummary(text) + if strings.Contains(result, "third") { + t.Errorf("sentence summary should not include third sentence: %q", result) + } +} + +func TestExtractKeywordSummary(t *testing.T) { + text := "The authentication service uses JWT tokens with RS256 signing algorithm for security." + result := extractKeywordSummary(text) + if result == "" { + t.Error("expected non-empty keyword summary") + } + // Should not contain stop words. + for _, stop := range []string{"the", "with", "for"} { + if strings.Contains(strings.ToLower(result), " "+stop+" ") { + t.Errorf("keyword summary contains stop word %q: %q", stop, result) + } + } +} + +func TestTotalTokens(t *testing.T) { + turns := []Turn{ + {TokenCount: 10}, + {TokenCount: 20}, + {TokenCount: 30}, + } + if TotalTokens(turns) != 60 { + t.Errorf("expected 60, got %d", TotalTokens(turns)) + } +} + +func TestSummarizeStats_ReductionPct(t *testing.T) { + s := NewHierarchicalSummarizer() + ctx := context.Background() + + longContent := strings.Repeat("This is a detailed message about the system architecture. ", 20) + turns := []Turn{ + makeTurn("1", "user", longContent, 25*time.Hour, 0.2), + makeTurn("2", "assistant", longContent, 25*time.Hour, 0.2), + } + + opts := DefaultOptions() + opts.PreserveRecent = 0 + + _, stats, err := s.Summarize(ctx, turns, opts) + if err != nil { + t.Fatalf("Summarize: %v", err) + } + if stats.ReductionPct <= 0 { + t.Errorf("expected positive reduction, got %.1f%%", stats.ReductionPct) + } +}