Skip to content

Commit 9e4d5d1

Browse files
committed
feat(anthropic): add prompt caching support
1 parent 7fcc124 commit 9e4d5d1

6 files changed

Lines changed: 421 additions & 47 deletions

File tree

provider/anthropic/messages/messages.go

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,15 @@ func resolveMaxTokens(params *sdk.GenerateParams, thinking *ThinkingConfig) *int
264264
func convertTools(tools []sdk.Tool) []anthropicTool {
265265
out := make([]anthropicTool, 0, len(tools))
266266
for _, t := range tools {
267-
out = append(out, anthropicTool{
267+
at := anthropicTool{
268268
Name: t.Name,
269269
Description: t.Description,
270270
InputSchema: t.Parameters,
271-
})
271+
}
272+
if t.CacheControl != nil {
273+
at.CacheControl = &cacheControl{Type: t.CacheControl.Type, TTL: t.CacheControl.TTL}
274+
}
275+
out = append(out, at)
272276
}
273277
return out
274278
}
@@ -307,6 +311,11 @@ func convertToolChoice(choice any) *anthropicToolChoice {
307311
// convertMessages splits SDK messages into Anthropic's system blocks and
308312
// alternating user/assistant messages. Tool result messages are merged into
309313
// user messages, as required by the Anthropic API.
314+
//
315+
// Note: GenerateParams.System (plain string) is converted to a single system
316+
// block without cache_control. To attach cache_control to a system prompt,
317+
// pass it as a MessageRoleSystem message with a TextPart that has CacheControl
318+
// set instead of using the System field.
310319
func convertMessages(params *sdk.GenerateParams) ([]contentBlock, []anthropicMessage) {
311320
var system []contentBlock
312321
var out []anthropicMessage
@@ -320,7 +329,11 @@ func convertMessages(params *sdk.GenerateParams) ([]contentBlock, []anthropicMes
320329
case sdk.MessageRoleSystem:
321330
for _, part := range msg.Content {
322331
if tp, ok := part.(sdk.TextPart); ok {
323-
system = append(system, contentBlock{Type: blockTypeText, Text: tp.Text})
332+
block := contentBlock{Type: blockTypeText, Text: tp.Text}
333+
if tp.CacheControl != nil {
334+
block.CacheControl = &cacheControl{Type: tp.CacheControl.Type, TTL: tp.CacheControl.TTL}
335+
}
336+
system = append(system, block)
324337
}
325338
}
326339

@@ -358,11 +371,19 @@ func convertUserContent(parts []sdk.MessagePart) []contentBlock {
358371
for _, part := range parts {
359372
switch p := part.(type) {
360373
case sdk.TextPart:
361-
blocks = append(blocks, contentBlock{Type: blockTypeText, Text: p.Text})
374+
block := contentBlock{Type: blockTypeText, Text: p.Text}
375+
if p.CacheControl != nil {
376+
block.CacheControl = &cacheControl{Type: p.CacheControl.Type, TTL: p.CacheControl.TTL}
377+
}
378+
blocks = append(blocks, block)
362379
case sdk.ImagePart:
363380
blocks = append(blocks, convertImagePart(p))
364381
case sdk.FilePart:
365-
blocks = append(blocks, contentBlock{Type: blockTypeText, Text: p.Data})
382+
block := contentBlock{Type: blockTypeText, Text: p.Data}
383+
if p.CacheControl != nil {
384+
block.CacheControl = &cacheControl{Type: p.CacheControl.Type, TTL: p.CacheControl.TTL}
385+
}
386+
blocks = append(blocks, block)
366387
}
367388
}
368389
return blocks
@@ -376,6 +397,11 @@ func convertImagePart(p sdk.ImagePart) contentBlock {
376397
image := strings.TrimSpace(p.Image)
377398
mediaType := strings.TrimSpace(p.MediaType)
378399

400+
var cc *cacheControl
401+
if p.CacheControl != nil {
402+
cc = &cacheControl{Type: p.CacheControl.Type, TTL: p.CacheControl.TTL}
403+
}
404+
379405
if strings.HasPrefix(strings.ToLower(image), "http://") || strings.HasPrefix(strings.ToLower(image), "https://") {
380406
return contentBlock{
381407
Type: "image",
@@ -384,6 +410,7 @@ func convertImagePart(p sdk.ImagePart) contentBlock {
384410
MediaType: mediaType,
385411
URL: image,
386412
},
413+
CacheControl: cc,
387414
}
388415
}
389416

@@ -413,6 +440,7 @@ func convertImagePart(p sdk.ImagePart) contentBlock {
413440
MediaType: mediaType,
414441
Data: image,
415442
},
443+
CacheControl: cc,
416444
}
417445
}
418446

@@ -771,15 +799,20 @@ func generateID() string {
771799

772800
func convertUsage(u *messagesUsage) sdk.Usage {
773801
total := u.InputTokens + u.OutputTokens
802+
detail := sdk.InputTokenDetail{
803+
CacheReadTokens: u.CacheReadInputTokens,
804+
CacheWriteTokens: u.CacheCreationInputTokens,
805+
}
806+
if u.CacheCreation != nil {
807+
detail.CacheWrite5mTokens = u.CacheCreation.Ephemeral5mInputTokens
808+
detail.CacheWrite1hTokens = u.CacheCreation.Ephemeral1hInputTokens
809+
}
774810
return sdk.Usage{
775811
InputTokens: u.InputTokens,
776812
OutputTokens: u.OutputTokens,
777813
TotalTokens: total,
778814
CachedInputTokens: u.CacheReadInputTokens,
779-
InputTokenDetails: sdk.InputTokenDetail{
780-
CacheReadTokens: u.CacheReadInputTokens,
781-
CacheWriteTokens: u.CacheCreationInputTokens,
782-
},
815+
InputTokenDetails: detail,
783816
}
784817
}
785818

0 commit comments

Comments
 (0)