Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions ai/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ type AgentCall struct {
PresencePenalty *float64 `json:"presence_penalty"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
ActiveTools []string `json:"active_tools"`
Headers map[string]string
ProviderOptions ProviderOptions
OnRetry OnRetryCallback
MaxRetries *int
Expand Down Expand Up @@ -336,10 +335,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
maps.Copy(headers, a.settings.headers)
}

if call.Headers != nil {
maps.Copy(headers, call.Headers)
}
call.Headers = headers
return call
}

Expand Down Expand Up @@ -420,7 +415,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
FrequencyPenalty: opts.FrequencyPenalty,
Tools: preparedTools,
ToolChoice: &stepToolChoice,
Headers: opts.Headers,
ProviderOptions: opts.ProviderOptions,
})
})
Expand Down Expand Up @@ -747,7 +741,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
PresencePenalty: opts.PresencePenalty,
FrequencyPenalty: opts.FrequencyPenalty,
ActiveTools: opts.ActiveTools,
Headers: opts.Headers,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
StopWhen: opts.StopWhen,
Expand Down Expand Up @@ -838,7 +831,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
FrequencyPenalty: call.FrequencyPenalty,
Tools: preparedTools,
ToolChoice: &stepToolChoice,
Headers: call.Headers,
ProviderOptions: call.ProviderOptions,
}

Expand Down Expand Up @@ -994,9 +986,8 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
if system != "" {
preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
}

preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
preparedPrompt = append(preparedPrompt, messages...)
preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
return preparedPrompt, nil
}

Expand Down Expand Up @@ -1077,6 +1068,11 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op

activeToolCalls := make(map[string]*ToolCallContent)
activeTextContent := make(map[string]string)
type reasoningContent struct {
content string
options ProviderMetadata
}
activeReasoningContent := make(map[string]reasoningContent)

// Process stream parts
for part := range stream {
Expand Down Expand Up @@ -1134,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}

case StreamPartTypeReasoningStart:
activeTextContent[part.ID] = ""
activeReasoningContent[part.ID] = reasoningContent{content: ""}
if opts.OnReasoningStart != nil {
err := opts.OnReasoningStart(part.ID)
if err != nil {
Expand All @@ -1143,8 +1139,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}

case StreamPartTypeReasoningDelta:
if _, exists := activeTextContent[part.ID]; exists {
activeTextContent[part.ID] += part.Delta
if active, exists := activeReasoningContent[part.ID]; exists {
active.content += part.Delta
active.options = part.ProviderMetadata
activeReasoningContent[part.ID] = active
}
if opts.OnReasoningDelta != nil {
err := opts.OnReasoningDelta(part.ID, part.Delta)
Expand All @@ -1154,21 +1152,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}

case StreamPartTypeReasoningEnd:
if text, exists := activeTextContent[part.ID]; exists {
stepContent = append(stepContent, ReasoningContent{
Text: text,
ProviderMetadata: part.ProviderMetadata,
})
if active, exists := activeReasoningContent[part.ID]; exists {
content := ReasoningContent{
Text: active.content,
ProviderMetadata: active.options,
}
stepContent = append(stepContent, content)
if opts.OnReasoningEnd != nil {
err := opts.OnReasoningEnd(part.ID, ReasoningContent{
Text: text,
ProviderMetadata: part.ProviderMetadata,
})
err := opts.OnReasoningEnd(part.ID, content)
if err != nil {
return StepResult{}, false, err
}
}
delete(activeTextContent, part.ID)
delete(activeReasoningContent, part.ID)
}

case StreamPartTypeToolInputStart:
Expand Down
36 changes: 0 additions & 36 deletions ai/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,42 +563,6 @@ func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
require.NotNil(t, result)
}

// Test options.headers
func TestAgent_Generate_OptionsHeaders(t *testing.T) {
t.Parallel()

model := &mockLanguageModel{
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
// Verify headers are passed
require.Equal(t, map[string]string{
"custom-request-header": "request-header-value",
}, call.Headers)

return &Response{
Content: []Content{
TextContent{Text: "Hello, world!"},
},
Usage: Usage{
InputTokens: 3,
OutputTokens: 10,
TotalTokens: 13,
},
FinishReason: FinishReasonStop,
}, nil
},
}

agent := NewAgent(model)
result, err := agent.Generate(context.Background(), AgentCall{
Prompt: "test-input",
Headers: map[string]string{"custom-request-header": "request-header-value"},
})

require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "Hello, world!", result.Response.Content.Text())
}

// Test options.activeTools filtering
func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
t.Parallel()
Expand Down
19 changes: 9 additions & 10 deletions ai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,15 @@ func SpecificToolChoice(name string) ToolChoice {
}

type Call struct {
Prompt Prompt `json:"prompt"`
MaxOutputTokens *int64 `json:"max_output_tokens"`
Temperature *float64 `json:"temperature"`
TopP *float64 `json:"top_p"`
TopK *int64 `json:"top_k"`
PresencePenalty *float64 `json:"presence_penalty"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
Tools []Tool `json:"tools"`
ToolChoice *ToolChoice `json:"tool_choice"`
Headers map[string]string `json:"headers"`
Prompt Prompt `json:"prompt"`
MaxOutputTokens *int64 `json:"max_output_tokens"`
Temperature *float64 `json:"temperature"`
TopP *float64 `json:"top_p"`
TopK *int64 `json:"top_k"`
PresencePenalty *float64 `json:"presence_penalty"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
Tools []Tool `json:"tools"`
ToolChoice *ToolChoice `json:"tool_choice"`

// for provider specific options, the key is the provider id
ProviderOptions ProviderOptions `json:"provider_options"`
Expand Down
20 changes: 10 additions & 10 deletions anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (a languageModel) Provider() string {

func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
params := &anthropic.MessageNewParams{}
providerOptions := &providerOptions{}
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions["anthropic"]; ok {
err := ai.ParseOptions(v, providerOptions)
if err != nil {
Expand Down Expand Up @@ -217,21 +217,21 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
return params, warnings, nil
}

func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOptions {
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
if cacheControl, ok := anthropicOptions["cache_control"]; ok {
if cc, ok := cacheControl.(map[string]any); ok {
cacheControlOption := &cacheControlProviderOptions{}
cacheControlOption := &CacheControlProviderOptions{}
err := ai.ParseOptions(cc, cacheControlOption)
if err != nil {
if err == nil {
return cacheControlOption
}
}
} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
if cc, ok := cacheControl.(map[string]any); ok {
cacheControlOption := &cacheControlProviderOptions{}
cacheControlOption := &CacheControlProviderOptions{}
err := ai.ParseOptions(cc, cacheControlOption)
if err != nil {
if err == nil {
return cacheControlOption
}
}
Expand All @@ -240,11 +240,11 @@ func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOp
return nil
}

func getReasoningMetadata(providerOptions ai.ProviderOptions) *reasoningMetadata {
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
reasoningMetadata := &reasoningMetadata{}
reasoningMetadata := &ReasoningMetadata{}
err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
if err != nil {
if err == nil {
return reasoningMetadata
}
}
Expand Down Expand Up @@ -837,7 +837,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
if !yield(ai.StreamPart{
Type: ai.StreamPartTypeReasoningDelta,
ID: fmt.Sprintf("%d", chunk.Index),
Delta: chunk.Delta.Text,
Delta: chunk.Delta.Thinking,
}) {
return
}
Expand Down
22 changes: 11 additions & 11 deletions anthropic/provider_options.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
package anthropic

type providerOptions struct {
SendReasoning *bool `json:"send_reasoning,omitempty"`
Thinking *thinkingProviderOption `json:"thinking,omitempty"`
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
type ProviderOptions struct {
SendReasoning *bool `mapstructure:"send_reasoning,omitempty"`
Thinking *ThinkingProviderOption `mapstructure:"thinking,omitempty"`
DisableParallelToolUse *bool `mapstructure:"disable_parallel_tool_use,omitempty"`
}

type thinkingProviderOption struct {
BudgetTokens int64 `json:"budget_tokens"`
type ThinkingProviderOption struct {
BudgetTokens int64 `mapstructure:"budget_tokens"`
}

type reasoningMetadata struct {
Signature string `json:"signature"`
RedactedData string `json:"redacted_data"`
type ReasoningMetadata struct {
Signature string `mapstructure:"signature"`
RedactedData string `mapstructure:"redacted_data"`
}

type cacheControlProviderOptions struct {
Type string `json:"type"`
type CacheControlProviderOptions struct {
Type string `mapstructure:"type"`
}
9 changes: 9 additions & 0 deletions cspell.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"language": "en",
"version": "0.2",
"flagWords": [],
"words": [
"mapstructure"
]
}
10 changes: 5 additions & 5 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (o languageModel) Provider() string {
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
params := &openai.ChatCompletionNewParams{}
messages, warnings := toPrompt(call.Prompt)
providerOptions := &providerOptions{}
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions["openai"]; ok {
err := ai.ParseOptions(v, providerOptions)
if err != nil {
Expand Down Expand Up @@ -239,13 +239,13 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar

if providerOptions.ReasoningEffort != nil {
switch *providerOptions.ReasoningEffort {
case reasoningEffortMinimal:
case ReasoningEffortMinimal:
params.ReasoningEffort = shared.ReasoningEffortMinimal
case reasoningEffortLow:
case ReasoningEffortLow:
params.ReasoningEffort = shared.ReasoningEffortLow
case reasoningEffortMedium:
case ReasoningEffortMedium:
params.ReasoningEffort = shared.ReasoningEffortMedium
case reasoningEffortHigh:
case ReasoningEffortHigh:
params.ReasoningEffort = shared.ReasoningEffortHigh
default:
return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)
Expand Down
Loading
Loading