Skip to content

Commit e48c6a1

Browse files
committed
test: add tests for stream with and without tool calls
1 parent 8048d48 commit e48c6a1

5 files changed

Lines changed: 324 additions & 0 deletions

File tree

providertests/provider_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package providertests
22

33
import (
44
"context"
5+
"strconv"
56
"strings"
67
"testing"
78

@@ -82,3 +83,140 @@ func TestTool(t *testing.T) {
8283
})
8384
}
8485
}
86+
87+
func TestStream(t *testing.T) {
88+
for _, pair := range languageModelBuilders {
89+
t.Run(pair.name, func(t *testing.T) {
90+
r := newRecorder(t)
91+
92+
languageModel, err := pair.builder(r)
93+
if err != nil {
94+
t.Fatalf("failed to build language model: %v", err)
95+
}
96+
97+
agent := ai.NewAgent(
98+
languageModel,
99+
ai.WithSystemPrompt("You are a helpful assistant"),
100+
)
101+
102+
var collectedText strings.Builder
103+
textDeltaCount := 0
104+
stepCount := 0
105+
106+
streamCall := ai.AgentStreamCall{
107+
Prompt: "Count from 1 to 3 in Spanish",
108+
OnTextDelta: func(id, text string) error {
109+
textDeltaCount++
110+
collectedText.WriteString(text)
111+
return nil
112+
},
113+
OnStepFinish: func(step ai.StepResult) error {
114+
stepCount++
115+
return nil
116+
},
117+
}
118+
119+
result, err := agent.Stream(t.Context(), streamCall)
120+
if err != nil {
121+
t.Fatalf("failed to stream: %v", err)
122+
}
123+
124+
finalText := result.Response.Content.Text()
125+
if finalText == "" {
126+
t.Fatal("expected non-empty response")
127+
}
128+
129+
if !strings.Contains(strings.ToLower(finalText), "uno") ||
130+
!strings.Contains(strings.ToLower(finalText), "dos") ||
131+
!strings.Contains(strings.ToLower(finalText), "tres") {
132+
t.Fatalf("unexpected response: %q", finalText)
133+
}
134+
135+
if textDeltaCount == 0 {
136+
t.Fatal("expected at least one text delta callback")
137+
}
138+
139+
if stepCount == 0 {
140+
t.Fatal("expected at least one step finish callback")
141+
}
142+
143+
if collectedText.String() == "" {
144+
t.Fatal("expected collected text from deltas to be non-empty")
145+
}
146+
})
147+
}
148+
}
149+
150+
func TestStreamWithTools(t *testing.T) {
151+
for _, pair := range languageModelBuilders {
152+
t.Run(pair.name, func(t *testing.T) {
153+
r := newRecorder(t)
154+
155+
languageModel, err := pair.builder(r)
156+
if err != nil {
157+
t.Fatalf("failed to build language model: %v", err)
158+
}
159+
160+
type CalculatorInput struct {
161+
A int `json:"a" description:"first number"`
162+
B int `json:"b" description:"second number"`
163+
}
164+
165+
calculatorTool := ai.NewAgentTool(
166+
"add",
167+
"Add two numbers",
168+
func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
169+
result := input.A + input.B
170+
return ai.NewTextResponse(strings.TrimSpace(strconv.Itoa(result))), nil
171+
},
172+
)
173+
174+
agent := ai.NewAgent(
175+
languageModel,
176+
ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
177+
ai.WithTools(calculatorTool),
178+
)
179+
180+
toolCallCount := 0
181+
toolResultCount := 0
182+
var collectedText strings.Builder
183+
184+
streamCall := ai.AgentStreamCall{
185+
Prompt: "What is 15 + 27?",
186+
OnTextDelta: func(id, text string) error {
187+
collectedText.WriteString(text)
188+
return nil
189+
},
190+
OnToolCall: func(toolCall ai.ToolCallContent) error {
191+
toolCallCount++
192+
if toolCall.ToolName != "add" {
193+
t.Errorf("unexpected tool name: %s", toolCall.ToolName)
194+
}
195+
return nil
196+
},
197+
OnToolResult: func(result ai.ToolResultContent) error {
198+
toolResultCount++
199+
return nil
200+
},
201+
}
202+
203+
result, err := agent.Stream(t.Context(), streamCall)
204+
if err != nil {
205+
t.Fatalf("failed to stream: %v", err)
206+
}
207+
208+
finalText := result.Response.Content.Text()
209+
if !strings.Contains(finalText, "42") {
210+
t.Fatalf("expected response to contain '42', got: %q", finalText)
211+
}
212+
213+
if toolCallCount == 0 {
214+
t.Fatal("expected at least one tool call")
215+
}
216+
217+
if toolResultCount == 0 {
218+
t.Fatal("expected at least one tool result")
219+
}
220+
})
221+
}
222+
}

providertests/testdata/TestStream/anthropic-claude-sonnet.yaml

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

providertests/testdata/TestStream/openai-gpt-4o.yaml

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)