Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions betamessageutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/charmbracelet/anthropic-sdk-go/internal/paramutil"
"github.com/charmbracelet/anthropic-sdk-go/packages/param"
"github.com/charmbracelet/anthropic-sdk-go/packages/respjson"
)

// Accumulate builds up the Message incrementally from a MessageStreamEvent. The Message then can be used as
Expand All @@ -28,8 +29,28 @@ func (acc *BetaMessage) Accumulate(event BetaRawMessageStreamEventUnion) error {
case BetaRawMessageDeltaEvent:
acc.StopReason = event.Delta.StopReason
acc.StopSequence = event.Delta.StopSequence
acc.Usage.OutputTokens = event.Usage.OutputTokens
acc.Usage.Iterations = event.Usage.Iterations
// Merge only fields present in the delta JSON, so earlier cumulative
// values (e.g. input and cache tokens from message_start) are preserved
// when a later delta omits them. Check Raw() against respjson.Omitted
// rather than using Valid(), because only JSON presence matters here.
if event.Usage.JSON.CacheCreationInputTokens.Raw() != respjson.Omitted {
acc.Usage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens
}
if event.Usage.JSON.CacheReadInputTokens.Raw() != respjson.Omitted {
acc.Usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens
}
if event.Usage.JSON.InputTokens.Raw() != respjson.Omitted {
acc.Usage.InputTokens = event.Usage.InputTokens
}
if event.Usage.JSON.Iterations.Raw() != respjson.Omitted {
acc.Usage.Iterations = event.Usage.Iterations
}
if event.Usage.JSON.OutputTokens.Raw() != respjson.Omitted {
acc.Usage.OutputTokens = event.Usage.OutputTokens
}
if event.Usage.JSON.ServerToolUse.Raw() != respjson.Omitted {
acc.Usage.ServerToolUse = event.Usage.ServerToolUse
}
acc.ContextManagement = event.ContextManagement
case BetaRawContentBlockStartEvent:
acc.Content = append(acc.Content, BetaContentBlockUnion{})
Expand Down
37 changes: 37 additions & 0 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,43 @@ func TestAccumulate(t *testing.T) {
`{"type: "message_stop"}`,
},
},
"usage tokens from message_start and message_delta": {
events: []string{
`{"type": "message_start", "message": {"usage": {"input_tokens": 100, "output_tokens": 0}}}`,
`{"type": "message_delta", "delta": {}, "usage": {"output_tokens": 50}}`,
`{"type": "message_stop"}`,
},
expected: anthropic.Message{Usage: anthropic.Usage{
InputTokens: 100,
OutputTokens: 50,
}},
},
"cache tokens preserved through message_delta": {
events: []string{
`{"type": "message_start", "message": {"usage": {"input_tokens": 200, "output_tokens": 0, "cache_creation_input_tokens": 30, "cache_read_input_tokens": 150}}}`,
`{"type": "message_delta", "delta": {}, "usage": {"output_tokens": 75, "cache_creation_input_tokens": 35, "cache_read_input_tokens": 160}}`,
`{"type": "message_stop"}`,
},
expected: anthropic.Message{Usage: anthropic.Usage{
InputTokens: 200,
OutputTokens: 75,
CacheCreationInputTokens: 35,
CacheReadInputTokens: 160,
}},
},
"message_delta does not clobber message_start usage": {
events: []string{
`{"type": "message_start", "message": {"usage": {"input_tokens": 111, "output_tokens": 5, "cache_creation_input_tokens": 22, "cache_read_input_tokens": 33}}}`,
`{"type": "message_delta", "delta": {}, "usage": {"output_tokens": 60}}`,
`{"type": "message_stop"}`,
},
expected: anthropic.Message{Usage: anthropic.Usage{
InputTokens: 111,
OutputTokens: 60,
CacheCreationInputTokens: 22,
CacheReadInputTokens: 33,
}},
},
"text content block": {
events: []string{
`{"type": "message_start", "message": {}}`,
Expand Down
21 changes: 20 additions & 1 deletion messageutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/charmbracelet/anthropic-sdk-go/internal/paramutil"
"github.com/charmbracelet/anthropic-sdk-go/packages/param"
"github.com/charmbracelet/anthropic-sdk-go/packages/respjson"
)

// Accumulate builds up the Message incrementally from a MessageStreamEvent. The Message then can be used as
Expand All @@ -28,7 +29,25 @@ func (acc *Message) Accumulate(event MessageStreamEventUnion) error {
case MessageDeltaEvent:
acc.StopReason = event.Delta.StopReason
acc.StopSequence = event.Delta.StopSequence
acc.Usage.OutputTokens = event.Usage.OutputTokens
// Merge only fields present in the delta JSON, so earlier cumulative
// values (e.g. input and cache tokens from message_start) are preserved
// when a later delta omits them. Check Raw() against respjson.Omitted
// rather than using Valid(), because only JSON presence matters here.
if event.Usage.JSON.CacheCreationInputTokens.Raw() != respjson.Omitted {
acc.Usage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens
}
if event.Usage.JSON.CacheReadInputTokens.Raw() != respjson.Omitted {
acc.Usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens
}
if event.Usage.JSON.InputTokens.Raw() != respjson.Omitted {
acc.Usage.InputTokens = event.Usage.InputTokens
}
if event.Usage.JSON.OutputTokens.Raw() != respjson.Omitted {
acc.Usage.OutputTokens = event.Usage.OutputTokens
}
if event.Usage.JSON.ServerToolUse.Raw() != respjson.Omitted {
acc.Usage.ServerToolUse = event.Usage.ServerToolUse
}
case ContentBlockStartEvent:
acc.Content = append(acc.Content, ContentBlockUnion{})
err := acc.Content[len(acc.Content)-1].UnmarshalJSON([]byte(event.ContentBlock.RawJSON()))
Expand Down
Loading