|
| 1 | +package session |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "database/sql" |
| 6 | + "fmt" |
| 7 | + "sort" |
| 8 | + "time" |
| 9 | +) |
| 10 | + |
| 11 | +// minCacheableTokens is Anthropic's minimum prefix size to qualify for prompt |
| 12 | +// caching (Claude 3.5+). |
| 13 | +const minCacheableTokens = 1024 |
| 14 | + |
| 15 | +// maxCacheMarkers is the maximum number of simultaneous cache_control markers |
| 16 | +// Anthropic allows per request. |
| 17 | +const maxCacheMarkers = 4 |
| 18 | + |
| 19 | +// CacheBoundaryConfig controls how the boundary manager classifies stability |
| 20 | +// and places cache_control markers. |
| 21 | +type CacheBoundaryConfig struct { |
| 22 | + // Enabled turns the boundary manager on or off. Default: true. |
| 23 | + Enabled bool |
| 24 | + |
| 25 | + // MinStableTurns is the number of consecutive pushes an entry must survive |
| 26 | + // unmodified before it is considered stable. Default: 2. |
| 27 | + MinStableTurns int |
| 28 | + |
| 29 | + // MinPrefixTokens is the minimum combined token count required before any |
| 30 | + // marker is placed. Matches Anthropic's 1024-token minimum. Default: 1024. |
| 31 | + MinPrefixTokens int |
| 32 | + |
| 33 | + // MaxMarkers is the maximum number of cache_control markers to place. |
| 34 | + // Anthropic allows up to 4. Default: 4. |
| 35 | + MaxMarkers int |
| 36 | +} |
| 37 | + |
| 38 | +// DefaultCacheBoundaryConfig returns sensible defaults. |
| 39 | +func DefaultCacheBoundaryConfig() CacheBoundaryConfig { |
| 40 | + return CacheBoundaryConfig{ |
| 41 | + Enabled: true, |
| 42 | + MinStableTurns: 2, |
| 43 | + MinPrefixTokens: minCacheableTokens, |
| 44 | + MaxMarkers: maxCacheMarkers, |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +// CacheBoundaryMarker describes a single cache_control placement. |
| 49 | +type CacheBoundaryMarker struct { |
| 50 | + // EntryID is the session entry that should carry the marker. |
| 51 | + EntryID string |
| 52 | + |
| 53 | + // TokensUpToHere is the cumulative token count of all entries up to and |
| 54 | + // including this one. |
| 55 | + TokensUpToHere int |
| 56 | + |
| 57 | + // StableSinceTurn is the push count at which this entry became stable. |
| 58 | + StableSinceTurn int |
| 59 | +} |
| 60 | + |
| 61 | +// CacheBoundaryResult is the output of a boundary evaluation. |
| 62 | +type CacheBoundaryResult struct { |
| 63 | + // Markers lists the recommended cache_control placements in order. |
| 64 | + Markers []CacheBoundaryMarker |
| 65 | + |
| 66 | + // TotalStableTokens is the combined token count of all stable entries. |
| 67 | + TotalStableTokens int |
| 68 | + |
| 69 | + // Advanced is true when the boundary moved forward since the last push. |
| 70 | + Advanced bool |
| 71 | + |
| 72 | + // Retreated is true when the boundary moved backward (content changed). |
| 73 | + Retreated bool |
| 74 | +} |
| 75 | + |
| 76 | +// CacheBoundaryManager evaluates the optimal cache_control placement after |
| 77 | +// each session push. It is embedded in SQLiteStore and called automatically |
| 78 | +// by Push when boundary management is enabled. |
| 79 | +type CacheBoundaryManager struct { |
| 80 | + db *sql.DB |
| 81 | + cfg CacheBoundaryConfig |
| 82 | +} |
| 83 | + |
| 84 | +// newCacheBoundaryManager creates a manager backed by the given database. |
| 85 | +func newCacheBoundaryManager(db *sql.DB, cfg CacheBoundaryConfig) *CacheBoundaryManager { |
| 86 | + return &CacheBoundaryManager{db: db, cfg: cfg} |
| 87 | +} |
| 88 | + |
| 89 | +// Evaluate computes the current optimal cache boundary for a session. |
| 90 | +// It returns the recommended markers and whether the boundary changed. |
| 91 | +func (m *CacheBoundaryManager) Evaluate(ctx context.Context, sessionID string) (*CacheBoundaryResult, error) { |
| 92 | + if !m.cfg.Enabled { |
| 93 | + return &CacheBoundaryResult{}, nil |
| 94 | + } |
| 95 | + |
| 96 | + // Load entries ordered by sequence. |
| 97 | + rows, err := m.db.QueryContext(ctx, |
| 98 | + `SELECT id, tokens, stable_since_turn, content_hash |
| 99 | + FROM session_entries |
| 100 | + WHERE session_id = ? |
| 101 | + ORDER BY seq ASC`, |
| 102 | + sessionID, |
| 103 | + ) |
| 104 | + if err != nil { |
| 105 | + return nil, fmt.Errorf("query entries: %w", err) |
| 106 | + } |
| 107 | + |
| 108 | + type entryRow struct { |
| 109 | + id string |
| 110 | + tokens int |
| 111 | + stableSince int |
| 112 | + contentHash string |
| 113 | + } |
| 114 | + var entries []entryRow |
| 115 | + for rows.Next() { |
| 116 | + var e entryRow |
| 117 | + if err := rows.Scan(&e.id, &e.tokens, &e.stableSince, &e.contentHash); err != nil { |
| 118 | + _ = rows.Close() |
| 119 | + return nil, err |
| 120 | + } |
| 121 | + entries = append(entries, e) |
| 122 | + } |
| 123 | + if err := rows.Err(); err != nil { |
| 124 | + _ = rows.Close() |
| 125 | + return nil, err |
| 126 | + } |
| 127 | + _ = rows.Close() |
| 128 | + |
| 129 | + minStable := m.cfg.MinStableTurns |
| 130 | + result := &CacheBoundaryResult{} |
| 131 | + cumTokens := 0 |
| 132 | + |
| 133 | + type candidate struct { |
| 134 | + entryID string |
| 135 | + cumTokens int |
| 136 | + stableSince int |
| 137 | + } |
| 138 | + var candidates []candidate |
| 139 | + |
| 140 | + for _, e := range entries { |
| 141 | + cumTokens += e.tokens |
| 142 | + if e.stableSince > 0 && e.stableSince <= minStable { |
| 143 | + // stable_since_turn is the push count when it first appeared; |
| 144 | + // it is considered stable once it has survived minStable pushes. |
| 145 | + // We store the push count at insertion; the current push count is |
| 146 | + // derived from the max stable_since_turn in the table. |
| 147 | + candidates = append(candidates, candidate{ |
| 148 | + entryID: e.id, |
| 149 | + cumTokens: cumTokens, |
| 150 | + stableSince: e.stableSince, |
| 151 | + }) |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + // Filter: only include entries whose cumulative token count meets the |
| 156 | + // minimum prefix requirement. |
| 157 | + var eligible []candidate |
| 158 | + for _, c := range candidates { |
| 159 | + if c.cumTokens >= m.cfg.MinPrefixTokens { |
| 160 | + eligible = append(eligible, c) |
| 161 | + } |
| 162 | + } |
| 163 | + |
| 164 | + // Sort by cumulative tokens descending to pick the largest stable prefixes. |
| 165 | + sort.Slice(eligible, func(i, j int) bool { |
| 166 | + return eligible[i].cumTokens > eligible[j].cumTokens |
| 167 | + }) |
| 168 | + |
| 169 | + // Cap at MaxMarkers. |
| 170 | + if len(eligible) > m.cfg.MaxMarkers { |
| 171 | + eligible = eligible[:m.cfg.MaxMarkers] |
| 172 | + } |
| 173 | + |
| 174 | + // Re-sort by cumTokens ascending so markers are in document order. |
| 175 | + sort.Slice(eligible, func(i, j int) bool { |
| 176 | + return eligible[i].cumTokens < eligible[j].cumTokens |
| 177 | + }) |
| 178 | + |
| 179 | + for _, c := range eligible { |
| 180 | + result.Markers = append(result.Markers, CacheBoundaryMarker{ |
| 181 | + EntryID: c.entryID, |
| 182 | + TokensUpToHere: c.cumTokens, |
| 183 | + StableSinceTurn: c.stableSince, |
| 184 | + }) |
| 185 | + result.TotalStableTokens = c.cumTokens |
| 186 | + } |
| 187 | + |
| 188 | + // Detect advance/retreat by comparing with the stored boundary. |
| 189 | + prev, err := m.loadStoredBoundary(ctx, sessionID) |
| 190 | + if err == nil { |
| 191 | + if result.TotalStableTokens > prev { |
| 192 | + result.Advanced = true |
| 193 | + } else if result.TotalStableTokens < prev && prev > 0 { |
| 194 | + result.Retreated = true |
| 195 | + } |
| 196 | + } |
| 197 | + |
| 198 | + // Persist the new boundary position. |
| 199 | + _ = m.storeBoundary(ctx, sessionID, result.TotalStableTokens) |
| 200 | + |
| 201 | + return result, nil |
| 202 | +} |
| 203 | + |
| 204 | +// loadStoredBoundary retrieves the last recorded boundary token count. |
| 205 | +func (m *CacheBoundaryManager) loadStoredBoundary(ctx context.Context, sessionID string) (int, error) { |
| 206 | + var tokens int |
| 207 | + err := m.db.QueryRowContext(ctx, |
| 208 | + "SELECT cache_boundary_tokens FROM sessions WHERE id = ?", |
| 209 | + sessionID, |
| 210 | + ).Scan(&tokens) |
| 211 | + if err != nil { |
| 212 | + return 0, err |
| 213 | + } |
| 214 | + return tokens, nil |
| 215 | +} |
| 216 | + |
| 217 | +// storeBoundary persists the current boundary token count. |
| 218 | +func (m *CacheBoundaryManager) storeBoundary(ctx context.Context, sessionID string, tokens int) error { |
| 219 | + _, err := m.db.ExecContext(ctx, |
| 220 | + "UPDATE sessions SET cache_boundary_tokens = ? WHERE id = ?", |
| 221 | + tokens, sessionID, |
| 222 | + ) |
| 223 | + return err |
| 224 | +} |
| 225 | + |
| 226 | +// RecordPush increments the push counter for a session and updates |
| 227 | +// stable_since_turn for entries that have now survived minStableTurns pushes. |
| 228 | +func (m *CacheBoundaryManager) RecordPush(ctx context.Context, sessionID string) error { |
| 229 | + if !m.cfg.Enabled { |
| 230 | + return nil |
| 231 | + } |
| 232 | + |
| 233 | + // Increment push count. |
| 234 | + _, err := m.db.ExecContext(ctx, |
| 235 | + "UPDATE sessions SET push_count = push_count + 1 WHERE id = ?", |
| 236 | + sessionID, |
| 237 | + ) |
| 238 | + if err != nil { |
| 239 | + return fmt.Errorf("increment push count: %w", err) |
| 240 | + } |
| 241 | + |
| 242 | + // Fetch current push count. |
| 243 | + var pushCount int |
| 244 | + if err := m.db.QueryRowContext(ctx, |
| 245 | + "SELECT push_count FROM sessions WHERE id = ?", |
| 246 | + sessionID, |
| 247 | + ).Scan(&pushCount); err != nil { |
| 248 | + return fmt.Errorf("read push count: %w", err) |
| 249 | + } |
| 250 | + |
| 251 | + // Mark entries as stable when they have survived minStableTurns pushes |
| 252 | + // without modification (stable_since_turn == 0 means not yet stable). |
| 253 | + stableThreshold := pushCount - m.cfg.MinStableTurns |
| 254 | + if stableThreshold > 0 { |
| 255 | + _, err = m.db.ExecContext(ctx, |
| 256 | + `UPDATE session_entries |
| 257 | + SET stable_since_turn = inserted_at_push |
| 258 | + WHERE session_id = ? |
| 259 | + AND stable_since_turn = 0 |
| 260 | + AND inserted_at_push <= ?`, |
| 261 | + sessionID, stableThreshold, |
| 262 | + ) |
| 263 | + if err != nil { |
| 264 | + return fmt.Errorf("mark stable entries: %w", err) |
| 265 | + } |
| 266 | + } |
| 267 | + |
| 268 | + return nil |
| 269 | +} |
| 270 | + |
| 271 | +// InvalidateEntry marks an entry as no longer stable (e.g. content changed). |
| 272 | +// The boundary will retreat on the next Evaluate call. |
| 273 | +func (m *CacheBoundaryManager) InvalidateEntry(ctx context.Context, entryID string) error { |
| 274 | + _, err := m.db.ExecContext(ctx, |
| 275 | + "UPDATE session_entries SET stable_since_turn = 0, content_hash = '' WHERE id = ?", |
| 276 | + entryID, |
| 277 | + ) |
| 278 | + return err |
| 279 | +} |
| 280 | + |
| 281 | +// BoundaryStats returns a snapshot of the current boundary state for a session. |
| 282 | +type BoundaryStats struct { |
| 283 | + SessionID string |
| 284 | + PushCount int |
| 285 | + BoundaryTokens int |
| 286 | + StableEntryCount int |
| 287 | + LastEvaluatedAt time.Time |
| 288 | +} |
| 289 | + |
| 290 | +// Stats returns boundary statistics for a session. |
| 291 | +func (m *CacheBoundaryManager) Stats(ctx context.Context, sessionID string) (*BoundaryStats, error) { |
| 292 | + var stats BoundaryStats |
| 293 | + stats.SessionID = sessionID |
| 294 | + |
| 295 | + err := m.db.QueryRowContext(ctx, |
| 296 | + "SELECT push_count, cache_boundary_tokens FROM sessions WHERE id = ?", |
| 297 | + sessionID, |
| 298 | + ).Scan(&stats.PushCount, &stats.BoundaryTokens) |
| 299 | + if err != nil { |
| 300 | + return nil, fmt.Errorf("read boundary stats: %w", err) |
| 301 | + } |
| 302 | + |
| 303 | + _ = m.db.QueryRowContext(ctx, |
| 304 | + "SELECT COUNT(*) FROM session_entries WHERE session_id = ? AND stable_since_turn > 0", |
| 305 | + sessionID, |
| 306 | + ).Scan(&stats.StableEntryCount) |
| 307 | + |
| 308 | + stats.LastEvaluatedAt = time.Now() |
| 309 | + return &stats, nil |
| 310 | +} |
0 commit comments