diff --git a/betamessageutil.go b/betamessageutil.go index 2d9f41e6..e0eb9b38 100644 --- a/betamessageutil.go +++ b/betamessageutil.go @@ -6,6 +6,7 @@ import ( "github.com/charmbracelet/anthropic-sdk-go/internal/paramutil" "github.com/charmbracelet/anthropic-sdk-go/packages/param" + "github.com/charmbracelet/anthropic-sdk-go/packages/respjson" ) // Accumulate builds up the Message incrementally from a MessageStreamEvent. The Message then can be used as @@ -28,8 +29,28 @@ func (acc *BetaMessage) Accumulate(event BetaRawMessageStreamEventUnion) error { case BetaRawMessageDeltaEvent: acc.StopReason = event.Delta.StopReason acc.StopSequence = event.Delta.StopSequence - acc.Usage.OutputTokens = event.Usage.OutputTokens - acc.Usage.Iterations = event.Usage.Iterations + // Merge only fields present in the delta JSON, so earlier cumulative + // values (e.g. input and cache tokens from message_start) are preserved + // when a later delta omits them. Check Raw() against respjson.Omitted + // rather than using Valid(), because only JSON presence matters here. + if event.Usage.JSON.CacheCreationInputTokens.Raw() != respjson.Omitted { + acc.Usage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens + } + if event.Usage.JSON.CacheReadInputTokens.Raw() != respjson.Omitted { + acc.Usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens + } + if event.Usage.JSON.InputTokens.Raw() != respjson.Omitted { + acc.Usage.InputTokens = event.Usage.InputTokens + } + if event.Usage.JSON.Iterations.Raw() != respjson.Omitted { + acc.Usage.Iterations = event.Usage.Iterations + } + if event.Usage.JSON.OutputTokens.Raw() != respjson.Omitted { + acc.Usage.OutputTokens = event.Usage.OutputTokens + } + if event.Usage.JSON.ServerToolUse.Raw() != respjson.Omitted { + acc.Usage.ServerToolUse = event.Usage.ServerToolUse + } acc.ContextManagement = event.ContextManagement case BetaRawContentBlockStartEvent: acc.Content = append(acc.Content, BetaContentBlockUnion{}) diff --git a/examples/go.mod b/examples/go.mod index 353f9ba5..1782f218 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -1,13 +1,13 @@ -module github.com/anthropic/anthropic-sdk-go/examples +module github.com/charmbracelet/anthropic-sdk-go/examples -replace github.com/anthropics/anthropic-sdk-go => ../ +replace github.com/charmbracelet/anthropic-sdk-go => ../ go 1.23.0 toolchain go1.24.3 require ( - github.com/anthropics/anthropic-sdk-go v0.0.0-00010101000000-000000000000 + github.com/charmbracelet/anthropic-sdk-go v0.0.0-00010101000000-000000000000 github.com/invopop/jsonschema v0.13.0 ) diff --git a/examples/go.sum b/examples/go.sum index 9d111685..65107bd7 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -44,6 +44,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -195,6 +197,8 @@ google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6h google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/message_test.go b/message_test.go index 52efb81a..52d9df05 100644 --- a/message_test.go +++ b/message_test.go @@ -249,6 +249,43 @@ func TestAccumulate(t *testing.T) { `{"type: "message_stop"}`, }, }, + "usage tokens from message_start and message_delta": { + events: []string{ + `{"type": "message_start", "message": {"usage": {"input_tokens": 100, "output_tokens": 0}}}`, + `{"type": "message_delta", "delta": {}, "usage": {"output_tokens": 50}}`, + `{"type": "message_stop"}`, + }, + expected: anthropic.Message{Usage: anthropic.Usage{ + InputTokens: 100, + OutputTokens: 50, + }}, + }, + "cache tokens preserved through message_delta": { + events: []string{ + `{"type": "message_start", "message": {"usage": {"input_tokens": 200, "output_tokens": 0, "cache_creation_input_tokens": 30, "cache_read_input_tokens": 150}}}`, + `{"type": "message_delta", "delta": {}, "usage": {"output_tokens": 75, "cache_creation_input_tokens": 35, "cache_read_input_tokens": 160}}`, + `{"type": "message_stop"}`, + }, + expected: anthropic.Message{Usage: anthropic.Usage{ + InputTokens: 200, + OutputTokens: 75, + CacheCreationInputTokens: 35, + CacheReadInputTokens: 160, + }}, + }, + "message_delta does not clobber message_start usage": { + events: []string{ + `{"type": "message_start", "message": {"usage": {"input_tokens": 111, "output_tokens": 5, "cache_creation_input_tokens": 22, "cache_read_input_tokens": 33}}}`, + `{"type": "message_delta", "delta": {}, "usage": {"output_tokens": 60}}`, + `{"type": "message_stop"}`, + }, + expected: anthropic.Message{Usage: anthropic.Usage{ + InputTokens: 111, + OutputTokens: 60, + CacheCreationInputTokens: 22, + CacheReadInputTokens: 33, + }}, + }, "text content block": { events: []string{ `{"type": "message_start", "message": {}}`, diff --git a/messageutil.go b/messageutil.go index ee0cd3f0..64766f39 100644 --- a/messageutil.go +++ b/messageutil.go @@ -6,6 +6,7 @@ import ( "github.com/charmbracelet/anthropic-sdk-go/internal/paramutil" "github.com/charmbracelet/anthropic-sdk-go/packages/param" + "github.com/charmbracelet/anthropic-sdk-go/packages/respjson" ) // Accumulate builds up the Message incrementally from a MessageStreamEvent. The Message then can be used as @@ -28,7 +29,25 @@ func (acc *Message) Accumulate(event MessageStreamEventUnion) error { case MessageDeltaEvent: acc.StopReason = event.Delta.StopReason acc.StopSequence = event.Delta.StopSequence - acc.Usage.OutputTokens = event.Usage.OutputTokens + // Merge only fields present in the delta JSON, so earlier cumulative + // values (e.g. input and cache tokens from message_start) are preserved + // when a later delta omits them. Check Raw() against respjson.Omitted + // rather than using Valid(), because only JSON presence matters here. + if event.Usage.JSON.CacheCreationInputTokens.Raw() != respjson.Omitted { + acc.Usage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens + } + if event.Usage.JSON.CacheReadInputTokens.Raw() != respjson.Omitted { + acc.Usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens + } + if event.Usage.JSON.InputTokens.Raw() != respjson.Omitted { + acc.Usage.InputTokens = event.Usage.InputTokens + } + if event.Usage.JSON.OutputTokens.Raw() != respjson.Omitted { + acc.Usage.OutputTokens = event.Usage.OutputTokens + } + if event.Usage.JSON.ServerToolUse.Raw() != respjson.Omitted { + acc.Usage.ServerToolUse = event.Usage.ServerToolUse + } case ContentBlockStartEvent: acc.Content = append(acc.Content, ContentBlockUnion{}) err := acc.Content[len(acc.Content)-1].UnmarshalJSON([]byte(event.ContentBlock.RawJSON()))