Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 99 additions & 42 deletions agent.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai

import (
"cmp"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -153,6 +154,75 @@ type AgentCall struct {
RepairToolCall RepairToolCallFunction
}

// Agent-level callbacks.
type (
// OnAgentStartFunc is called when agent starts.
OnAgentStartFunc func()

// OnAgentFinishFunc is called when agent finishes.
OnAgentFinishFunc func(result *AgentResult) error

// OnStepStartFunc is called when a step starts.
OnStepStartFunc func(stepNumber int) error

// OnStepFinishFunc is called when a step finishes.
OnStepFinishFunc func(stepResult StepResult) error

// OnFinishFunc is called when entire agent completes.
OnFinishFunc func(result *AgentResult)

// OnErrorFunc is called when an error occurs.
OnErrorFunc func(error)
)

// Stream part callbacks - called for each corresponding stream part type.
type (
// OnChunkFunc is called for each stream part (catch-all).
OnChunkFunc func(StreamPart) error

// OnWarningsFunc is called for warnings.
OnWarningsFunc func(warnings []CallWarning) error

// OnTextStartFunc is called when text starts.
OnTextStartFunc func(id string) error

// OnTextDeltaFunc is called for text deltas.
OnTextDeltaFunc func(id, text string) error

// OnTextEndFunc is called when text ends.
OnTextEndFunc func(id string) error

// OnReasoningStartFunc is called when reasoning starts.
OnReasoningStartFunc func(id string) error

// OnReasoningDeltaFunc is called for reasoning deltas.
OnReasoningDeltaFunc func(id, text string) error

// OnReasoningEndFunc is called when reasoning ends.
OnReasoningEndFunc func(id string, reasoning ReasoningContent) error

// OnToolInputStartFunc is called when tool input starts.
OnToolInputStartFunc func(id, toolName string) error

// OnToolInputDeltaFunc is called for tool input deltas.
OnToolInputDeltaFunc func(id, delta string) error

// OnToolInputEndFunc is called when tool input ends.
OnToolInputEndFunc func(id string) error

// OnToolCallFunc is called when tool call is complete.
OnToolCallFunc func(toolCall ToolCallContent) error

// OnToolResultFunc is called when tool execution completes.
OnToolResultFunc func(result ToolResultContent) error

// OnSourceFunc is called for source references.
OnSourceFunc func(source SourceContent) error

// OnStreamFinishFunc is called when stream finishes.
OnStreamFinishFunc func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error
)

type AgentStreamCall struct {
Prompt string `json:"prompt"`
Files []FilePart `json:"files"`
Expand All @@ -174,29 +244,29 @@ type AgentStreamCall struct {
RepairToolCall RepairToolCallFunction

// Agent-level callbacks
OnAgentStart func() // Called when agent starts
OnAgentFinish func(result *AgentResult) error // Called when agent finishes
OnStepStart func(stepNumber int) error // Called when a step starts
OnStepFinish func(stepResult StepResult) error // Called when a step finishes
OnFinish func(result *AgentResult) // Called when entire agent completes
OnError func(error) // Called when an error occurs
OnAgentStart OnAgentStartFunc // Called when agent starts
OnAgentFinish OnAgentFinishFunc // Called when agent finishes
OnStepStart OnStepStartFunc // Called when a step starts
OnStepFinish OnStepFinishFunc // Called when a step finishes
OnFinish OnFinishFunc // Called when entire agent completes
OnError OnErrorFunc // Called when an error occurs

// Stream part callbacks - called for each corresponding stream part type
OnChunk func(StreamPart) error // Called for each stream part (catch-all)
OnWarnings func(warnings []CallWarning) error // Called for warnings
OnTextStart func(id string) error // Called when text starts
OnTextDelta func(id, text string) error // Called for text deltas
OnTextEnd func(id string) error // Called when text ends
OnReasoningStart func(id string) error // Called when reasoning starts
OnReasoningDelta func(id, text string) error // Called for reasoning deltas
OnReasoningEnd func(id string, reasoning ReasoningContent) error // Called when reasoning ends
OnToolInputStart func(id, toolName string) error // Called when tool input starts
OnToolInputDelta func(id, delta string) error // Called for tool input deltas
OnToolInputEnd func(id string) error // Called when tool input ends
OnToolCall func(toolCall ToolCallContent) error // Called when tool call is complete
OnToolResult func(result ToolResultContent) error // Called when tool execution completes
OnSource func(source SourceContent) error // Called for source references
OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
OnChunk OnChunkFunc // Called for each stream part (catch-all)
OnWarnings OnWarningsFunc // Called for warnings
OnTextStart OnTextStartFunc // Called when text starts
OnTextDelta OnTextDeltaFunc // Called for text deltas
OnTextEnd OnTextEndFunc // Called when text ends
OnReasoningStart OnReasoningStartFunc // Called when reasoning starts
OnReasoningDelta OnReasoningDeltaFunc // Called for reasoning deltas
OnReasoningEnd OnReasoningEndFunc // Called when reasoning ends
OnToolInputStart OnToolInputStartFunc // Called when tool input starts
OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
OnToolCall OnToolCallFunc // Called when tool call is complete
OnToolResult OnToolResultFunc // Called when tool execution completes
OnSource OnSourceFunc // Called for source references
OnStreamFinish OnStreamFinishFunc // Called when stream finishes
}

type AgentResult struct {
Expand Down Expand Up @@ -230,24 +300,14 @@ func NewAgent(model LanguageModel, opts ...AgentOption) Agent {
}

func (a *agent) prepareCall(call AgentCall) AgentCall {
if call.MaxOutputTokens == nil && a.settings.maxOutputTokens != nil {
call.MaxOutputTokens = a.settings.maxOutputTokens
}
if call.Temperature == nil && a.settings.temperature != nil {
call.Temperature = a.settings.temperature
}
if call.TopP == nil && a.settings.topP != nil {
call.TopP = a.settings.topP
}
if call.TopK == nil && a.settings.topK != nil {
call.TopK = a.settings.topK
}
if call.PresencePenalty == nil && a.settings.presencePenalty != nil {
call.PresencePenalty = a.settings.presencePenalty
}
if call.FrequencyPenalty == nil && a.settings.frequencyPenalty != nil {
call.FrequencyPenalty = a.settings.frequencyPenalty
}
call.MaxOutputTokens = cmp.Or(call.MaxOutputTokens, a.settings.maxOutputTokens)
call.Temperature = cmp.Or(call.Temperature, a.settings.temperature)
call.TopP = cmp.Or(call.TopP, a.settings.topP)
call.TopK = cmp.Or(call.TopK, a.settings.topK)
call.PresencePenalty = cmp.Or(call.PresencePenalty, a.settings.presencePenalty)
call.FrequencyPenalty = cmp.Or(call.FrequencyPenalty, a.settings.frequencyPenalty)
call.MaxRetries = cmp.Or(call.MaxRetries, a.settings.maxRetries)

if len(call.StopWhen) == 0 && len(a.settings.stopWhen) > 0 {
call.StopWhen = a.settings.stopWhen
}
Expand All @@ -260,9 +320,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
if call.OnRetry == nil && a.settings.onRetry != nil {
call.OnRetry = a.settings.onRetry
}
if call.MaxRetries == nil && a.settings.maxRetries != nil {
call.MaxRetries = a.settings.maxRetries
}

providerOptions := ProviderOptions{}
if a.settings.providerOptions != nil {
Expand Down
14 changes: 9 additions & 5 deletions examples/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ import (
"os"

"github.com/charmbracelet/ai"
"github.com/charmbracelet/ai/providers"
"github.com/charmbracelet/ai/providers/openai"
)

func main() {
provider := providers.NewOpenAiProvider(
providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")),
provider := openai.New(
openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")),
)
model, err := provider.LanguageModel("gpt-4o")
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}

// Create weather tool using the new type-safe API
Expand All @@ -38,9 +38,13 @@ func main() {
ai.WithTools(weatherTool),
)

result, _ := agent.Generate(context.Background(), ai.AgentCall{
result, err := agent.Generate(context.Background(), ai.AgentCall{
Prompt: "What's the weather in pristina",
})
if err != nil {
fmt.Println(err)
os.Exit(1)
}

fmt.Println("Steps: ", len(result.Steps))
for _, s := range result.Steps {
Expand Down
8 changes: 4 additions & 4 deletions examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ import (
"os"

"github.com/charmbracelet/ai"
"github.com/charmbracelet/ai/providers"
"github.com/charmbracelet/ai/providers/anthropic"
)

func main() {
provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
model, err := provider.LanguageModel("claude-sonnet-4-20250514")
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}

response, err := model.Generate(context.Background(), ai.Call{
Expand All @@ -25,7 +25,7 @@ func main() {
})
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}

fmt.Println("Assistant: ", response.Content.Text())
Expand Down
14 changes: 9 additions & 5 deletions examples/stream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ import (
"os"

"github.com/charmbracelet/ai"
"github.com/charmbracelet/ai/providers"
"github.com/charmbracelet/ai/providers/openai"
)

func main() {
provider := providers.NewOpenAiProvider(providers.WithOpenAiAPIKey(os.Getenv("OPENAI_API_KEY")))
provider := openai.New(openai.WithAPIKey(os.Getenv("OPENAI_API_KEY")))
model, err := provider.LanguageModel("gpt-4o")
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}

stream, err := model.Stream(context.Background(), ai.Call{
Expand Down Expand Up @@ -44,11 +44,15 @@ func main() {
})
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}

for chunk := range stream {
data, _ := json.Marshal(chunk)
data, err := json.Marshal(chunk)
if err != nil {
fmt.Println(err)
continue
}
fmt.Println(string(data))
}
}
8 changes: 4 additions & 4 deletions examples/streaming-agent-simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"os"

"github.com/charmbracelet/ai"
"github.com/charmbracelet/ai/providers"
"github.com/charmbracelet/ai/providers/openai"
)

func main() {
Expand All @@ -18,13 +18,13 @@ func main() {
}

// Create provider and model
provider := providers.NewOpenAiProvider(
providers.WithOpenAiAPIKey(apiKey),
provider := openai.New(
openai.WithAPIKey(apiKey),
)
model, err := provider.LanguageModel("gpt-4o-mini")
if err != nil {
fmt.Println(err)
return
os.Exit(1)
}

// Create echo tool using the new type-safe API
Expand Down
4 changes: 2 additions & 2 deletions examples/streaming-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"strings"

"github.com/charmbracelet/ai"
"github.com/charmbracelet/ai/providers"
"github.com/charmbracelet/ai/providers/anthropic"
)

func main() {
Expand All @@ -24,7 +24,7 @@ func main() {
fmt.Println()

// Create OpenAI provider and model
provider := providers.NewAnthropicProvider(providers.WithAnthropicAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
provider := anthropic.New(anthropic.WithAPIKey(os.Getenv("ANTHROPIC_API_KEY")))
model, err := provider.LanguageModel("claude-sonnet-4-20250514")
if err != nil {
fmt.Println(err)
Expand Down
14 changes: 14 additions & 0 deletions internal/jsonext/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package jsonext

import (
"encoding/json"
)

func IsValidJSON[T string | []byte](data T) bool {
if len(data) == 0 { // hot path
return false
}
var m json.RawMessage
err := json.Unmarshal([]byte(data), &m)
return err == nil
}
Loading
Loading