Skip to content

Commit 215059b

Browse files
Minor fixes (#8)
* chore: remove unused headers * feat: expose provider options * fix: use mapstructure tag * fix: thinking delta * fix: agent prompt * fix: reasoning metadata * fix: reasoning content * fix: test
1 parent 245142f commit 215059b

9 files changed

Lines changed: 99 additions & 131 deletions

File tree

ai/agent.go

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ type AgentCall struct {
144144
PresencePenalty *float64 `json:"presence_penalty"`
145145
FrequencyPenalty *float64 `json:"frequency_penalty"`
146146
ActiveTools []string `json:"active_tools"`
147-
Headers map[string]string
148147
ProviderOptions ProviderOptions
149148
OnRetry OnRetryCallback
150149
MaxRetries *int
@@ -336,10 +335,6 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
336335
maps.Copy(headers, a.settings.headers)
337336
}
338337

339-
if call.Headers != nil {
340-
maps.Copy(headers, call.Headers)
341-
}
342-
call.Headers = headers
343338
return call
344339
}
345340

@@ -420,7 +415,6 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
420415
FrequencyPenalty: opts.FrequencyPenalty,
421416
Tools: preparedTools,
422417
ToolChoice: &stepToolChoice,
423-
Headers: opts.Headers,
424418
ProviderOptions: opts.ProviderOptions,
425419
})
426420
})
@@ -747,7 +741,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
747741
PresencePenalty: opts.PresencePenalty,
748742
FrequencyPenalty: opts.FrequencyPenalty,
749743
ActiveTools: opts.ActiveTools,
750-
Headers: opts.Headers,
751744
ProviderOptions: opts.ProviderOptions,
752745
MaxRetries: opts.MaxRetries,
753746
StopWhen: opts.StopWhen,
@@ -838,7 +831,6 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
838831
FrequencyPenalty: call.FrequencyPenalty,
839832
Tools: preparedTools,
840833
ToolChoice: &stepToolChoice,
841-
Headers: call.Headers,
842834
ProviderOptions: call.ProviderOptions,
843835
}
844836

@@ -994,9 +986,8 @@ func (a *agent) createPrompt(system, prompt string, messages []Message, files ..
994986
if system != "" {
995987
preparedPrompt = append(preparedPrompt, NewSystemMessage(system))
996988
}
997-
998-
preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
999989
preparedPrompt = append(preparedPrompt, messages...)
990+
preparedPrompt = append(preparedPrompt, NewUserMessage(prompt, files...))
1000991
return preparedPrompt, nil
1001992
}
1002993

@@ -1077,6 +1068,11 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
10771068

10781069
activeToolCalls := make(map[string]*ToolCallContent)
10791070
activeTextContent := make(map[string]string)
1071+
type reasoningContent struct {
1072+
content string
1073+
options ProviderMetadata
1074+
}
1075+
activeReasoningContent := make(map[string]reasoningContent)
10801076

10811077
// Process stream parts
10821078
for part := range stream {
@@ -1134,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11341130
}
11351131

11361132
case StreamPartTypeReasoningStart:
1137-
activeTextContent[part.ID] = ""
1133+
activeReasoningContent[part.ID] = reasoningContent{content: ""}
11381134
if opts.OnReasoningStart != nil {
11391135
err := opts.OnReasoningStart(part.ID)
11401136
if err != nil {
@@ -1143,8 +1139,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11431139
}
11441140

11451141
case StreamPartTypeReasoningDelta:
1146-
if _, exists := activeTextContent[part.ID]; exists {
1147-
activeTextContent[part.ID] += part.Delta
1142+
if active, exists := activeReasoningContent[part.ID]; exists {
1143+
active.content += part.Delta
1144+
active.options = part.ProviderMetadata
1145+
activeReasoningContent[part.ID] = active
11481146
}
11491147
if opts.OnReasoningDelta != nil {
11501148
err := opts.OnReasoningDelta(part.ID, part.Delta)
@@ -1154,21 +1152,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11541152
}
11551153

11561154
case StreamPartTypeReasoningEnd:
1157-
if text, exists := activeTextContent[part.ID]; exists {
1158-
stepContent = append(stepContent, ReasoningContent{
1159-
Text: text,
1160-
ProviderMetadata: part.ProviderMetadata,
1161-
})
1155+
if active, exists := activeReasoningContent[part.ID]; exists {
1156+
content := ReasoningContent{
1157+
Text: active.content,
1158+
ProviderMetadata: active.options,
1159+
}
1160+
stepContent = append(stepContent, content)
11621161
if opts.OnReasoningEnd != nil {
1163-
err := opts.OnReasoningEnd(part.ID, ReasoningContent{
1164-
Text: text,
1165-
ProviderMetadata: part.ProviderMetadata,
1166-
})
1162+
err := opts.OnReasoningEnd(part.ID, content)
11671163
if err != nil {
11681164
return StepResult{}, false, err
11691165
}
11701166
}
1171-
delete(activeTextContent, part.ID)
1167+
delete(activeReasoningContent, part.ID)
11721168
}
11731169

11741170
case StreamPartTypeToolInputStart:

ai/agent_test.go

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -563,42 +563,6 @@ func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
563563
require.NotNil(t, result)
564564
}
565565

566-
// Test options.headers
567-
func TestAgent_Generate_OptionsHeaders(t *testing.T) {
568-
t.Parallel()
569-
570-
model := &mockLanguageModel{
571-
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
572-
// Verify headers are passed
573-
require.Equal(t, map[string]string{
574-
"custom-request-header": "request-header-value",
575-
}, call.Headers)
576-
577-
return &Response{
578-
Content: []Content{
579-
TextContent{Text: "Hello, world!"},
580-
},
581-
Usage: Usage{
582-
InputTokens: 3,
583-
OutputTokens: 10,
584-
TotalTokens: 13,
585-
},
586-
FinishReason: FinishReasonStop,
587-
}, nil
588-
},
589-
}
590-
591-
agent := NewAgent(model)
592-
result, err := agent.Generate(context.Background(), AgentCall{
593-
Prompt: "test-input",
594-
Headers: map[string]string{"custom-request-header": "request-header-value"},
595-
})
596-
597-
require.NoError(t, err)
598-
require.NotNil(t, result)
599-
require.Equal(t, "Hello, world!", result.Response.Content.Text())
600-
}
601-
602566
// Test options.activeTools filtering
603567
func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
604568
t.Parallel()

ai/model.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,16 +176,15 @@ func SpecificToolChoice(name string) ToolChoice {
176176
}
177177

178178
type Call struct {
179-
Prompt Prompt `json:"prompt"`
180-
MaxOutputTokens *int64 `json:"max_output_tokens"`
181-
Temperature *float64 `json:"temperature"`
182-
TopP *float64 `json:"top_p"`
183-
TopK *int64 `json:"top_k"`
184-
PresencePenalty *float64 `json:"presence_penalty"`
185-
FrequencyPenalty *float64 `json:"frequency_penalty"`
186-
Tools []Tool `json:"tools"`
187-
ToolChoice *ToolChoice `json:"tool_choice"`
188-
Headers map[string]string `json:"headers"`
179+
Prompt Prompt `json:"prompt"`
180+
MaxOutputTokens *int64 `json:"max_output_tokens"`
181+
Temperature *float64 `json:"temperature"`
182+
TopP *float64 `json:"top_p"`
183+
TopK *int64 `json:"top_k"`
184+
PresencePenalty *float64 `json:"presence_penalty"`
185+
FrequencyPenalty *float64 `json:"frequency_penalty"`
186+
Tools []Tool `json:"tools"`
187+
ToolChoice *ToolChoice `json:"tool_choice"`
189188

190189
// for provider specific options, the key is the provider id
191190
ProviderOptions ProviderOptions `json:"provider_options"`

anthropic/anthropic.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func (a languageModel) Provider() string {
118118

119119
func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams, []ai.CallWarning, error) {
120120
params := &anthropic.MessageNewParams{}
121-
providerOptions := &providerOptions{}
121+
providerOptions := &ProviderOptions{}
122122
if v, ok := call.ProviderOptions["anthropic"]; ok {
123123
err := ai.ParseOptions(v, providerOptions)
124124
if err != nil {
@@ -217,21 +217,21 @@ func (a languageModel) prepareParams(call ai.Call) (*anthropic.MessageNewParams,
217217
return params, warnings, nil
218218
}
219219

220-
func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOptions {
220+
func getCacheControl(providerOptions ai.ProviderOptions) *CacheControlProviderOptions {
221221
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
222222
if cacheControl, ok := anthropicOptions["cache_control"]; ok {
223223
if cc, ok := cacheControl.(map[string]any); ok {
224-
cacheControlOption := &cacheControlProviderOptions{}
224+
cacheControlOption := &CacheControlProviderOptions{}
225225
err := ai.ParseOptions(cc, cacheControlOption)
226-
if err != nil {
226+
if err == nil {
227227
return cacheControlOption
228228
}
229229
}
230230
} else if cacheControl, ok := anthropicOptions["cacheControl"]; ok {
231231
if cc, ok := cacheControl.(map[string]any); ok {
232-
cacheControlOption := &cacheControlProviderOptions{}
232+
cacheControlOption := &CacheControlProviderOptions{}
233233
err := ai.ParseOptions(cc, cacheControlOption)
234-
if err != nil {
234+
if err == nil {
235235
return cacheControlOption
236236
}
237237
}
@@ -240,11 +240,11 @@ func getCacheControl(providerOptions ai.ProviderOptions) *cacheControlProviderOp
240240
return nil
241241
}
242242

243-
func getReasoningMetadata(providerOptions ai.ProviderOptions) *reasoningMetadata {
243+
func getReasoningMetadata(providerOptions ai.ProviderOptions) *ReasoningMetadata {
244244
if anthropicOptions, ok := providerOptions["anthropic"]; ok {
245-
reasoningMetadata := &reasoningMetadata{}
245+
reasoningMetadata := &ReasoningMetadata{}
246246
err := ai.ParseOptions(anthropicOptions, reasoningMetadata)
247-
if err != nil {
247+
if err == nil {
248248
return reasoningMetadata
249249
}
250250
}
@@ -837,7 +837,7 @@ func (a languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamRespo
837837
if !yield(ai.StreamPart{
838838
Type: ai.StreamPartTypeReasoningDelta,
839839
ID: fmt.Sprintf("%d", chunk.Index),
840-
Delta: chunk.Delta.Text,
840+
Delta: chunk.Delta.Thinking,
841841
}) {
842842
return
843843
}

anthropic/provider_options.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
package anthropic
22

3-
type providerOptions struct {
4-
SendReasoning *bool `json:"send_reasoning,omitempty"`
5-
Thinking *thinkingProviderOption `json:"thinking,omitempty"`
6-
DisableParallelToolUse *bool `json:"disable_parallel_tool_use,omitempty"`
3+
type ProviderOptions struct {
4+
SendReasoning *bool `mapstructure:"send_reasoning,omitempty"`
5+
Thinking *ThinkingProviderOption `mapstructure:"thinking,omitempty"`
6+
DisableParallelToolUse *bool `mapstructure:"disable_parallel_tool_use,omitempty"`
77
}
88

9-
type thinkingProviderOption struct {
10-
BudgetTokens int64 `json:"budget_tokens"`
9+
type ThinkingProviderOption struct {
10+
BudgetTokens int64 `mapstructure:"budget_tokens"`
1111
}
1212

13-
type reasoningMetadata struct {
14-
Signature string `json:"signature"`
15-
RedactedData string `json:"redacted_data"`
13+
type ReasoningMetadata struct {
14+
Signature string `mapstructure:"signature"`
15+
RedactedData string `mapstructure:"redacted_data"`
1616
}
1717

18-
type cacheControlProviderOptions struct {
19-
Type string `json:"type"`
18+
type CacheControlProviderOptions struct {
19+
Type string `mapstructure:"type"`
2020
}

cspell.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"language": "en",
3+
"version": "0.2",
4+
"flagWords": [],
5+
"words": [
6+
"mapstructure",
7+
"mapstructure"
8+
]
9+
}

openai/openai.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ func (o languageModel) Provider() string {
145145
func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewParams, []ai.CallWarning, error) {
146146
params := &openai.ChatCompletionNewParams{}
147147
messages, warnings := toPrompt(call.Prompt)
148-
providerOptions := &providerOptions{}
148+
providerOptions := &ProviderOptions{}
149149
if v, ok := call.ProviderOptions["openai"]; ok {
150150
err := ai.ParseOptions(v, providerOptions)
151151
if err != nil {
@@ -239,13 +239,13 @@ func (o languageModel) prepareParams(call ai.Call) (*openai.ChatCompletionNewPar
239239

240240
if providerOptions.ReasoningEffort != nil {
241241
switch *providerOptions.ReasoningEffort {
242-
case reasoningEffortMinimal:
242+
case ReasoningEffortMinimal:
243243
params.ReasoningEffort = shared.ReasoningEffortMinimal
244-
case reasoningEffortLow:
244+
case ReasoningEffortLow:
245245
params.ReasoningEffort = shared.ReasoningEffortLow
246-
case reasoningEffortMedium:
246+
case ReasoningEffortMedium:
247247
params.ReasoningEffort = shared.ReasoningEffortMedium
248-
case reasoningEffortHigh:
248+
case ReasoningEffortHigh:
249249
params.ReasoningEffort = shared.ReasoningEffortHigh
250250
default:
251251
return nil, nil, fmt.Errorf("reasoning model `%s` not supported", *providerOptions.ReasoningEffort)

0 commit comments

Comments
 (0)