Skip to content

Commit b513c44

Browse files
committed
test: add more specific reasoning tests
1 parent fb3017c commit b513c44

1 file changed

Lines changed: 62 additions & 1 deletion

File tree

providertests/provider_test.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package providertests
22

33
import (
44
"context"
5+
"fmt"
56
"strconv"
67
"strings"
78
"testing"
@@ -10,8 +11,8 @@ import (
1011
"github.com/charmbracelet/fantasy/anthropic"
1112
"github.com/charmbracelet/fantasy/google"
1213
"github.com/charmbracelet/fantasy/openai"
13-
"github.com/stretchr/testify/require"
1414
_ "github.com/joho/godotenv/autoload"
15+
"github.com/stretchr/testify/require"
1516
)
1617

1718
func TestSimple(t *testing.T) {
@@ -127,6 +128,8 @@ func TestThinking(t *testing.T) {
127128
want2 := "40"
128129
got := result.Response.Content.Text()
129130
require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
131+
132+
testThinkingSteps(t, languageModel.Provider(), result.Steps)
130133
})
131134
}
132135
}
@@ -181,10 +184,68 @@ func TestThinkingStreaming(t *testing.T) {
181184
want2 := "40"
182185
got := result.Response.Content.Text()
183186
require.True(t, strings.Contains(got, want1) && strings.Contains(got, want2), "unexpected response: got %q, want %q %q", got, want1, want2)
187+
188+
testThinkingSteps(t, languageModel.Provider(), result.Steps)
184189
})
185190
}
186191
}
187192

193+
func testThinkingSteps(t *testing.T, providerName string, steps []ai.StepResult) {
194+
if providerName == anthropic.Name {
195+
reasoningContentCount := 0
196+
signaturesCount := 0
197+
// Test if we got the signature
198+
for _, step := range steps {
199+
for _, msg := range step.Messages {
200+
for _, content := range msg.Content {
201+
if content.GetType() == ai.ContentTypeReasoning {
202+
reasoningContentCount += 1
203+
reasoningContent, ok := ai.AsContentType[ai.ReasoningPart](content)
204+
if !ok {
205+
continue
206+
}
207+
if len(reasoningContent.ProviderOptions) == 0 {
208+
continue
209+
}
210+
211+
anthropicReasoningMetadata, ok := reasoningContent.ProviderOptions[anthropic.Name]
212+
if !ok {
213+
continue
214+
}
215+
if reasoningContent.Text != "" {
216+
if typed, ok := anthropicReasoningMetadata.(*anthropic.ReasoningOptionMetadata); ok {
217+
require.NotEmpty(t, typed.Signature)
218+
signaturesCount += 1
219+
}
220+
}
221+
}
222+
}
223+
}
224+
}
225+
require.Greater(t, reasoningContentCount, 0)
226+
require.Greater(t, signaturesCount, 0)
227+
require.Equal(t, reasoningContentCount, signaturesCount)
228+
} else if providerName == google.Name {
229+
reasoningContentCount := 0
230+
// Test if we got the signature
231+
for _, step := range steps {
232+
for _, msg := range step.Messages {
233+
for _, content := range msg.Content {
234+
if content.GetType() == ai.ContentTypeReasoning {
235+
reasoningContentCount += 1
236+
reasoningContent, ok := ai.AsContentType[ai.ReasoningContent](content)
237+
if !ok {
238+
continue
239+
}
240+
fmt.Println(reasoningContent.Text)
241+
}
242+
}
243+
}
244+
}
245+
require.Greater(t, reasoningContentCount, 0)
246+
}
247+
}
248+
188249
func TestStream(t *testing.T) {
189250
for _, pair := range languageModelBuilders {
190251
t.Run(pair.name, func(t *testing.T) {

0 commit comments

Comments
 (0)