Skip to content

Commit b77e637

Browse files
committed
feat: support structured output for agents
To enable this, the structured-output methods of LanguageModel had to be updated to support tool calls, as the normal, non-structured generation methods already do. These changes are fully backward compatible. Note that for this feature to work, each provider must explicitly support tool calls alongside structured output; this commit implements that only for the OpenAI provider. If a provider lacks support, the agent will only generate structured output without executing tool calls. Related to #118.
1 parent 32045be commit b77e637

4 files changed

Lines changed: 125 additions & 23 deletions

File tree

agent.go

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type stepExecutionResult struct {
3030
// StopCondition defines a function that determines when an agent should stop executing.
3131
type StopCondition = func(steps []StepResult) bool
3232

33+
// responseGenerator produces the model response for a single agent step.
// The step loop calls it with the model and Call it assembled for that step,
// so the same loop can serve both plain and structured-output generation.
type responseGenerator = func(ctx context.Context, model LanguageModel, call Call) (*Response, error)
34+
3335
// StepCountIs returns a stop condition that stops after the specified number of steps.
3436
func StepCountIs(stepCount int) StopCondition {
3537
return func(steps []StepResult) bool {
@@ -304,6 +306,7 @@ type AgentResult struct {
304306
// Agent represents an AI agent that can generate responses and stream responses.
305307
type Agent interface {
306308
// Generate runs the agent for the given call and returns the aggregated result.
Generate(context.Context, AgentCall) (*AgentResult, error)
309+
// GenerateObject runs the agent like Generate, but asks the model for
// structured output described by the given schema.
GenerateObject(context.Context, schema.Schema, AgentCall) (*AgentResult, error)
307310
Stream(context.Context, AgentStreamCall) (*AgentResult, error)
308311
}
309312

@@ -367,13 +370,12 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
367370
return call
368371
}
369372

370-
// Generate implements Agent.
371-
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
372-
opts = a.prepareCall(opts)
373-
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
374-
if err != nil {
375-
return nil, err
376-
}
373+
func (a *agent) executeLoop(
374+
ctx context.Context,
375+
initialPrompt Prompt,
376+
gen responseGenerator,
377+
opts AgentCall,
378+
) ([]StepResult, error) {
377379
var responseMessages []Message
378380
var steps []StepResult
379381

@@ -446,7 +448,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
446448
retryOptions.OnRetry = opts.OnRetry
447449
retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
448450
result, err := retry(ctx, func() (*Response, error) {
449-
return stepModel.Generate(ctx, Call{
451+
return gen(ctx, stepModel, Call{
450452
Prompt: stepInputMessages,
451453
MaxOutputTokens: opts.MaxOutputTokens,
452454
Temperature: opts.Temperature,
@@ -485,7 +487,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
485487

486488
toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
487489

488-
// Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is.
490+
// Build step content with validated tool calls and tool results. Provider-executed tool calls are kept as-is.
489491
stepContent := []Content{}
490492
toolCallIndex := 0
491493
for _, content := range result.Content {
@@ -528,8 +530,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
528530
}
529531
}
530532

531-
totalUsage := Usage{}
533+
//nolint:nilerr // tool execution failure breaks the loop but does not prevent an answer from being returned
534+
return steps, nil
535+
}
532536

537+
func toAgentResult(steps []StepResult) *AgentResult {
538+
totalUsage := Usage{}
533539
for _, step := range steps {
534540
usage := step.Usage
535541
totalUsage.InputTokens += usage.InputTokens
@@ -540,12 +546,89 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
540546
totalUsage.TotalTokens += usage.TotalTokens
541547
}
542548

543-
agentResult := &AgentResult{
549+
return &AgentResult{
544550
Steps: steps,
545551
Response: steps[len(steps)-1].Response,
546552
TotalUsage: totalUsage,
547553
}
548-
return agentResult, nil
554+
}
555+
556+
// Generate implements Agent.
//
// It applies the agent defaults to opts, builds the initial prompt, and runs
// the shared step loop using the step model's Generate method to produce each
// response. The collected steps are folded into an AgentResult.
557+
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
558+
opts = a.prepareCall(opts)
559+
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
560+
if err != nil {
561+
return nil, err
562+
}
563+
steps, err := a.executeLoop(
564+
ctx,
565+
initialPrompt,
566+
// Plain generation: forward the step's call to the model unchanged.
func(ctx context.Context, stepModel LanguageModel, call Call) (*Response, error) {
567+
return stepModel.Generate(ctx, call)
568+
},
569+
opts,
570+
)
571+
if err != nil {
572+
return nil, err
573+
}
574+
575+
return toAgentResult(steps), nil
576+
}
577+
578+
// GenerateObject implements Agent.
//
// It runs the same step loop as Generate, but each step calls the model's
// GenerateObject with schema s. The step's Call is translated into an
// ObjectCall (including its Tools and ToolChoice), and the ObjectResponse is
// translated back into a Response so the loop can process it like any other
// step.
func (a *agent) GenerateObject(ctx context.Context, s schema.Schema, opts AgentCall) (*AgentResult, error) {
579+
opts = a.prepareCall(opts)
580+
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
581+
if err != nil {
582+
return nil, err
583+
}
584+
585+
steps, err := a.executeLoop(
586+
ctx,
587+
initialPrompt,
588+
// Structured generation: adapt Call -> ObjectCall and ObjectResponse -> Response.
func(ctx context.Context, model LanguageModel, call Call) (*Response, error) {
589+
res, err := model.GenerateObject(ctx, ObjectCall{
590+
Prompt: call.Prompt,
591+
Schema: s,
592+
MaxOutputTokens: call.MaxOutputTokens,
593+
Temperature: call.Temperature,
594+
TopP: call.TopP,
595+
TopK: call.TopK,
596+
PresencePenalty: call.PresencePenalty,
597+
FrequencyPenalty: call.FrequencyPenalty,
598+
UserAgent: call.UserAgent,
599+
ProviderOptions: call.ProviderOptions,
600+
// No repair function is supplied on the agent path.
RepairText: nil,
601+
Tools: call.Tools,
602+
ToolChoice: call.ToolChoice,
603+
})
604+
if err != nil {
605+
return nil, err
606+
}
607+
608+
// Rebuild the response content: tool calls first, then the raw text (if any).
var content ResponseContent
609+
for _, toolCall := range res.ToolCalls {
610+
content = append(content, toolCall)
611+
}
612+
613+
// Only the raw text is carried over as text content.
// NOTE(review): any parsed value on res is not copied into the Response —
// confirm callers are expected to read the structured output from the text.
if res.RawText != "" {
614+
content = append(content, TextContent{Text: res.RawText})
615+
}
616+
617+
return &Response{
618+
Content: content,
619+
FinishReason: res.FinishReason,
620+
Usage: res.Usage,
621+
Warnings: res.Warnings,
622+
ProviderMetadata: res.ProviderMetadata,
623+
}, nil
624+
},
625+
opts,
626+
)
627+
if err != nil {
628+
return nil, err
629+
}
630+
631+
return toAgentResult(steps), nil
549632
}
550633

551634
func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {

object.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ type ObjectCall struct {
4747
ProviderOptions ProviderOptions
4848

4949
RepairText schema.ObjectRepairFunc
50+
51+
Tools []Tool `json:"tools"`
52+
ToolChoice *ToolChoice `json:"tool_choice"`
5053
}
5154

5255
// ObjectResponse represents the response from a structured object generation.
@@ -57,6 +60,7 @@ type ObjectResponse struct {
5760
FinishReason FinishReason
5861
Warnings []CallWarning
5962
ProviderMetadata ProviderMetadata
63+
ToolCalls []ToolCallContent
6064
}
6165

6266
// ObjectStreamPartType indicates the type of stream part.
@@ -99,6 +103,7 @@ type ObjectResult[T any] struct {
99103
FinishReason FinishReason
100104
Warnings []CallWarning
101105
ProviderMetadata ProviderMetadata
106+
ToolCalls []ToolCallContent
102107
}
103108

104109
// StreamObjectResult provides typed access to a streaming object generation result.

object/object.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func Generate[T any](
5151
FinishReason: resp.FinishReason,
5252
Warnings: resp.Warnings,
5353
ProviderMetadata: resp.ProviderMetadata,
54+
ToolCalls: resp.ToolCalls,
5455
}, nil
5556
}
5657

providers/openai/responses_language_model.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13251325
PresencePenalty: call.PresencePenalty,
13261326
FrequencyPenalty: call.FrequencyPenalty,
13271327
ProviderOptions: call.ProviderOptions,
1328+
Tools: call.Tools,
1329+
ToolChoice: call.ToolChoice,
13281330
}
13291331

13301332
params, warnings, err := o.prepareParams(fantasyCall)
@@ -1350,8 +1352,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13501352
}
13511353
}
13521354

1353-
// Extract JSON text from response
13541355
var jsonText string
1356+
var toolCalls []fantasy.ToolCallContent
13551357
for _, outputItem := range response.Output {
13561358
if outputItem.Type == "message" {
13571359
for _, contentPart := range outputItem.Content {
@@ -1361,15 +1363,20 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13611363
}
13621364
}
13631365
}
1366+
if outputItem.Type == "function_call" {
1367+
toolCalls = append(toolCalls, fantasy.ToolCallContent{
1368+
ProviderExecuted: false,
1369+
ToolCallID: outputItem.CallID,
1370+
ToolName: outputItem.Name,
1371+
Input: outputItem.Arguments.OfString,
1372+
})
1373+
}
13641374
}
13651375

1366-
if jsonText == "" {
1367-
usage := fantasy.Usage{
1368-
InputTokens: response.Usage.InputTokens,
1369-
OutputTokens: response.Usage.OutputTokens,
1370-
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
1371-
}
1372-
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
1376+
usage := responsesUsage(*response)
1377+
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
1378+
1379+
if jsonText == "" && len(toolCalls) == 0 {
13731380
return nil, &fantasy.NoObjectGeneratedError{
13741381
RawText: "",
13751382
ParseError: fmt.Errorf("no text content in response"),
@@ -1378,6 +1385,14 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13781385
}
13791386
}
13801387

1388+
if jsonText == "" && len(toolCalls) > 0 {
1389+
return &fantasy.ObjectResponse{
1390+
Usage: usage,
1391+
FinishReason: finishReason,
1392+
ToolCalls: toolCalls,
1393+
}, nil
1394+
}
1395+
13811396
// Parse and validate
13821397
var obj any
13831398
if call.RepairText != nil {
@@ -1386,9 +1401,6 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13861401
obj, err = schema.ParseAndValidate(jsonText, call.Schema)
13871402
}
13881403

1389-
usage := responsesUsage(*response)
1390-
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
1391-
13921404
if err != nil {
13931405
// Add usage info to error
13941406
if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
@@ -1405,6 +1417,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
14051417
FinishReason: finishReason,
14061418
Warnings: warnings,
14071419
ProviderMetadata: responsesProviderMetadata(response.ID),
1420+
ToolCalls: toolCalls,
14081421
}, nil
14091422
}
14101423

0 commit comments

Comments
 (0)