Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
providertests/testdata/**/*.yaml -diff linguist-generated=true
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ require (
github.com/charmbracelet/x/json v0.2.0
github.com/go-viper/mapstructure/v2 v2.4.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/openai/openai-go/v2 v2.3.0
github.com/stretchr/testify v1.11.1
gopkg.in/dnaeon/go-vcr.v4 v4.0.5
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/openai/openai-go/v2 v2.3.0 h1:y9U+V1tlHjvvb/5XIswuySqnG5EnKBFAbMxgBvTHXvg=
github.com/openai/openai-go/v2 v2.3.0/go.mod h1:sIUkR+Cu/PMUVkSKhkk742PRURkQOCFhiwJ7eRSBqmk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand All @@ -26,5 +30,7 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/dnaeon/go-vcr.v4 v4.0.5 h1:I0hpTIvD5rII+8LgYGrHMA2d4SQPoL6u7ZvJakWKsiA=
gopkg.in/dnaeon/go-vcr.v4 v4.0.5/go.mod h1:dRos81TkW9C1WJt6tTaE+uV2Lo8qJT3AG2b35+CB/nQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
2 changes: 2 additions & 0 deletions providertests/.env.sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ANTHROPIC_API_KEY=
OPENAI_API_KEY=
48 changes: 48 additions & 0 deletions providertests/builders_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package providertests

import (
"net/http"
"os"

"github.com/charmbracelet/ai/ai"
"github.com/charmbracelet/ai/anthropic"
"github.com/charmbracelet/ai/openai"
"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
)

// builderFunc constructs a language model given the recorder passed in by
// the test; it returns an error if the model cannot be created.
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)

// builderPair associates a label with the builder that constructs the
// corresponding language model.
type builderPair struct {
// name labels the model; the provider tests use it as the subtest name.
name string
// builder constructs the language model for this entry.
builder builderFunc
}

// languageModelBuilders lists every provider/model combination that the
// provider tests iterate over. Each name becomes a subtest name, and the
// recorder derives its cassette path from the test name, so renaming an
// entry also changes which cassette file is used.
var languageModelBuilders = []builderPair{
	{name: "openai-gpt-4o", builder: builderOpenaiGpt4o},
	{name: "openai-gpt-4o-mini", builder: builderOpenaiGpt4oMini},
	{name: "anthropic-claude-sonnet", builder: builderAnthropicClaudeSonnet4},
}

// builderOpenaiGpt4o returns the OpenAI "gpt-4o" language model. The API key
// is read from the OPENAI_API_KEY environment variable and all HTTP traffic
// is sent through the recorder r.
func builderOpenaiGpt4o(r *recorder.Recorder) (ai.LanguageModel, error) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	httpClient := &http.Client{Transport: r}

	return openai.New(
		openai.WithAPIKey(apiKey),
		openai.WithHTTPClient(httpClient),
	).LanguageModel("gpt-4o")
}

// builderOpenaiGpt4oMini returns the OpenAI "gpt-4o-mini" language model. The
// API key is read from the OPENAI_API_KEY environment variable and all HTTP
// traffic is sent through the recorder r.
func builderOpenaiGpt4oMini(r *recorder.Recorder) (ai.LanguageModel, error) {
	apiKey := os.Getenv("OPENAI_API_KEY")
	httpClient := &http.Client{Transport: r}

	return openai.New(
		openai.WithAPIKey(apiKey),
		openai.WithHTTPClient(httpClient),
	).LanguageModel("gpt-4o-mini")
}

// builderAnthropicClaudeSonnet4 returns the Anthropic
// "claude-sonnet-4-20250514" language model. The API key is read from the
// ANTHROPIC_API_KEY environment variable and all HTTP traffic is sent through
// the recorder r.
func builderAnthropicClaudeSonnet4(r *recorder.Recorder) (ai.LanguageModel, error) {
	apiKey := os.Getenv("ANTHROPIC_API_KEY")
	httpClient := &http.Client{Transport: r}

	return anthropic.New(
		anthropic.WithAPIKey(apiKey),
		anthropic.WithHTTPClient(httpClient),
	).LanguageModel("claude-sonnet-4-20250514")
}
223 changes: 223 additions & 0 deletions providertests/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
package providertests

import (
"context"
"strconv"
"strings"
"testing"

"github.com/charmbracelet/ai/ai"
_ "github.com/joho/godotenv/autoload"
)

// TestSimple sends a plain text prompt to every configured language model
// and checks that the reply contains a Portuguese greeting.
func TestSimple(t *testing.T) {
	for _, tc := range languageModelBuilders {
		t.Run(tc.name, func(t *testing.T) {
			model, err := tc.builder(newRecorder(t))
			if err != nil {
				t.Fatalf("failed to build language model: %v", err)
			}

			agent := ai.NewAgent(model, ai.WithSystemPrompt("You are a helpful assistant"))

			result, err := agent.Generate(t.Context(), ai.AgentCall{Prompt: "Say hi in Portuguese"})
			if err != nil {
				t.Fatalf("failed to generate: %v", err)
			}

			// Either greeting is an acceptable answer.
			const (
				greetingA = "Oi"
				greetingB = "Olá"
			)
			got := result.Response.Content.Text()
			if strings.Contains(got, greetingA) || strings.Contains(got, greetingB) {
				return
			}
			t.Fatalf("unexpected response: got %q, want %q or %q", got, greetingA, greetingB)
		})
	}
}

// TestTool registers a stub "weather" tool with every configured language
// model and checks that the final answer mentions both the requested city
// and the temperature the stub reports.
func TestTool(t *testing.T) {
	for _, tc := range languageModelBuilders {
		t.Run(tc.name, func(t *testing.T) {
			model, err := tc.builder(newRecorder(t))
			if err != nil {
				t.Fatalf("failed to build language model: %v", err)
			}

			type WeatherInput struct {
				Location string `json:"location" description:"the city"`
			}

			// The stub ignores its input and always reports the same reading.
			reportWeather := func(_ context.Context, _ WeatherInput, _ ai.ToolCall) (ai.ToolResponse, error) {
				return ai.NewTextResponse("40 C"), nil
			}
			weatherTool := ai.NewAgentTool("weather", "Get weather information for a location", reportWeather)

			agent := ai.NewAgent(
				model,
				ai.WithSystemPrompt("You are a helpful assistant"),
				ai.WithTools(weatherTool),
			)

			result, err := agent.Generate(t.Context(), ai.AgentCall{Prompt: "What's the weather in Florence?"})
			if err != nil {
				t.Fatalf("failed to generate: %v", err)
			}

			const (
				wantCity = "Florence"
				wantTemp = "40"
			)
			got := result.Response.Content.Text()
			if strings.Contains(got, wantCity) && strings.Contains(got, wantTemp) {
				return
			}
			t.Fatalf("unexpected response: got %q, want %q %q", got, wantCity, wantTemp)
		})
	}
}

// TestStream streams a short counting prompt through every configured
// language model. It checks the final text for the Spanish words for one,
// two and three, and verifies that the text-delta and step-finish callbacks
// were invoked.
func TestStream(t *testing.T) {
	for _, tc := range languageModelBuilders {
		t.Run(tc.name, func(t *testing.T) {
			model, err := tc.builder(newRecorder(t))
			if err != nil {
				t.Fatalf("failed to build language model: %v", err)
			}

			agent := ai.NewAgent(model, ai.WithSystemPrompt("You are a helpful assistant"))

			var (
				streamed strings.Builder
				deltas   int
				steps    int
			)

			result, err := agent.Stream(t.Context(), ai.AgentStreamCall{
				Prompt: "Count from 1 to 3 in Spanish",
				OnTextDelta: func(_, delta string) error {
					deltas++
					streamed.WriteString(delta)
					return nil
				},
				OnStepFinish: func(_ ai.StepResult) error {
					steps++
					return nil
				},
			})
			if err != nil {
				t.Fatalf("failed to stream: %v", err)
			}

			finalText := result.Response.Content.Text()
			if finalText == "" {
				t.Fatal("expected non-empty response")
			}

			// Compare case-insensitively; lowercase the text only once.
			lowered := strings.ToLower(finalText)
			for _, word := range []string{"uno", "dos", "tres"} {
				if !strings.Contains(lowered, word) {
					t.Fatalf("unexpected response: %q", finalText)
				}
			}

			switch {
			case deltas == 0:
				t.Fatal("expected at least one text delta callback")
			case steps == 0:
				t.Fatal("expected at least one step finish callback")
			case streamed.Len() == 0:
				t.Fatal("expected collected text from deltas to be non-empty")
			}
		})
	}
}

// TestStreamWithTools streams an arithmetic prompt through every configured
// language model with an "add" tool attached. It checks that the final text
// contains the expected sum and that the tool-call and tool-result callbacks
// each fired at least once.
func TestStreamWithTools(t *testing.T) {
	for _, pair := range languageModelBuilders {
		t.Run(pair.name, func(t *testing.T) {
			r := newRecorder(t)

			languageModel, err := pair.builder(r)
			if err != nil {
				t.Fatalf("failed to build language model: %v", err)
			}

			type CalculatorInput struct {
				A int `json:"a" description:"first number"`
				B int `json:"b" description:"second number"`
			}

			calculatorTool := ai.NewAgentTool(
				"add",
				"Add two numbers",
				func(ctx context.Context, input CalculatorInput, _ ai.ToolCall) (ai.ToolResponse, error) {
					// strconv.Itoa never emits surrounding whitespace, so the
					// result can be returned without trimming.
					return ai.NewTextResponse(strconv.Itoa(input.A + input.B)), nil
				},
			)

			agent := ai.NewAgent(
				languageModel,
				ai.WithSystemPrompt("You are a helpful assistant. Use the add tool to perform calculations."),
				ai.WithTools(calculatorTool),
			)

			toolCallCount := 0
			toolResultCount := 0
			var collectedText strings.Builder

			streamCall := ai.AgentStreamCall{
				Prompt: "What is 15 + 27?",
				OnTextDelta: func(id, text string) error {
					collectedText.WriteString(text)
					return nil
				},
				OnToolCall: func(toolCall ai.ToolCallContent) error {
					toolCallCount++
					// Errorf rather than Fatalf: record the mismatch but let
					// the stream finish.
					if toolCall.ToolName != "add" {
						t.Errorf("unexpected tool name: %s", toolCall.ToolName)
					}
					return nil
				},
				OnToolResult: func(result ai.ToolResultContent) error {
					toolResultCount++
					return nil
				},
			}

			result, err := agent.Stream(t.Context(), streamCall)
			if err != nil {
				t.Fatalf("failed to stream: %v", err)
			}

			finalText := result.Response.Content.Text()
			if !strings.Contains(finalText, "42") {
				t.Fatalf("expected response to contain '42', got: %q", finalText)
			}

			if toolCallCount == 0 {
				t.Fatal("expected at least one tool call")
			}

			if toolResultCount == 0 {
				t.Fatal("expected at least one tool result")
			}
		})
	}
}
75 changes: 75 additions & 0 deletions providertests/recorder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package providertests

import (
"bytes"
"io"
"net/http"
"path/filepath"
"strings"
"testing"

"gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
)

// newRecorder creates a go-vcr recorder whose cassette lives under
// testdata/ at a path derived from the running test's name. The recorder is
// stopped automatically when the test finishes; a failure to create it
// aborts the test, and a failure to stop it is reported as a test error.
func newRecorder(t *testing.T) *recorder.Recorder {
	cassettePath := filepath.Join("testdata", t.Name())

	rec, err := recorder.New(
		cassettePath,
		recorder.WithMode(recorder.ModeRecordOnce),
		recorder.WithMatcher(customMatcher(t)),
		// Skip the sleep that simulates the recorded response time so the
		// tests run faster.
		recorder.WithSkipRequestLatency(true),
		recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
	)
	if err != nil {
		t.Fatalf("recorder: failed to create recorder: %v", err)
	}

	t.Cleanup(func() {
		stopErr := rec.Stop()
		if stopErr == nil {
			return
		}
		t.Errorf("recorder: failed to stop recorder: %v", stopErr)
	})

	return rec
}

// customMatcher returns a matcher that compares an outgoing request with a
// recorded one. Requests without a body are delegated to
// cassette.DefaultMatcher. Requests with a body match when method, URL and
// the exact body bytes are all equal to the recorded interaction.
//
// The request body is restored after being read so that it can still be
// consumed afterwards.
func customMatcher(t *testing.T) recorder.MatcherFunc {
	return func(r *http.Request, i cassette.Request) bool {
		if r.Body == nil || r.Body == http.NoBody {
			return cassette.DefaultMatcher(r, i)
		}

		// Cheap checks first: there is no point in reading and re-buffering
		// the body when method or URL already rule out a match.
		if r.Method != i.Method || r.URL.String() != i.URL {
			return false
		}

		reqBody, err := io.ReadAll(r.Body)
		if err != nil {
			// Errorf rather than Fatalf: the matcher is invoked from the HTTP
			// transport, which may not be running on the test goroutine.
			t.Errorf("recorder: failed to read request body: %v", err)
			return false
		}
		// The body has been read in full; a close error cannot affect the
		// comparison, so it is deliberately ignored.
		_ = r.Body.Close()
		r.Body = io.NopCloser(bytes.NewReader(reqBody))

		return string(reqBody) == i.Body
	}
}

// headersToKeep is the allow-list of header names that hookRemoveHeaders
// leaves in an interaction. Keys are lowercase because
// hookRemoveHeaders lowercases each header name before the lookup.
var headersToKeep = map[string]struct{}{
"accept": {},
"content-type": {},
"user-agent": {},
}

// hookRemoveHeaders deletes every request and response header whose
// lowercased name is not listed in headersToKeep. It always returns nil.
func hookRemoveHeaders(i *cassette.Interaction) error {
	for name := range i.Request.Headers {
		if _, keep := headersToKeep[strings.ToLower(name)]; keep {
			continue
		}
		delete(i.Request.Headers, name)
	}

	for name := range i.Response.Headers {
		if _, keep := headersToKeep[strings.ToLower(name)]; keep {
			continue
		}
		delete(i.Response.Headers, name)
	}

	return nil
}
Loading
Loading