Skip to content

Commit f4a5e15

Browse files
committed
feat: support structured output with tools/agents
1 parent 32045be commit f4a5e15

4 files changed

Lines changed: 124 additions & 23 deletions

File tree

agent.go

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ type stepExecutionResult struct {
3030
// StopCondition defines a function that determines when an agent should stop executing.
// It is given a slice of step results and returns true when execution should stop.
3131
type StopCondition = func(steps []StepResult) bool
3232

33+
// responseGenerator produces a Response for a single model call.
// executeLoop takes one as its gen parameter and invokes it with the step's
// model and Call, so that Generate and GenerateObject can share the same loop
// while calling different model methods.
type responseGenerator = func(ctx context.Context, model LanguageModel, call Call) (*Response, error)
34+
3335
// StepCountIs returns a stop condition that stops after the specified number of steps.
3436
func StepCountIs(stepCount int) StopCondition {
3537
return func(steps []StepResult) bool {
@@ -304,6 +306,7 @@ type AgentResult struct {
304306
// Agent represents an AI agent that can generate responses and stream responses.
//
// NOTE(review): GenerateObject is a newly added method; adding a method to an
// exported interface breaks any type outside this package that implemented
// Agent before — confirm that no external implementations exist.
305307
type Agent interface {
// Generate runs the agent for the given call and returns the aggregated result.
306308
Generate(context.Context, AgentCall) (*AgentResult, error)
// GenerateObject is like Generate but additionally takes a schema.Schema
// describing the structured output to request from the model.
309+
GenerateObject(context.Context, schema.Schema, AgentCall) (*AgentResult, error)
// Stream runs the agent for the given streaming call and returns the result.
307310
Stream(context.Context, AgentStreamCall) (*AgentResult, error)
308311
}
309312

@@ -367,13 +370,12 @@ func (a *agent) prepareCall(call AgentCall) AgentCall {
367370
return call
368371
}
369372

370-
// Generate implements Agent.
371-
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
372-
opts = a.prepareCall(opts)
373-
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
374-
if err != nil {
375-
return nil, err
376-
}
373+
func (a *agent) executeLoop(
374+
ctx context.Context,
375+
initialPrompt Prompt,
376+
gen responseGenerator,
377+
opts AgentCall,
378+
) ([]StepResult, error) {
377379
var responseMessages []Message
378380
var steps []StepResult
379381

@@ -446,7 +448,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
446448
retryOptions.OnRetry = opts.OnRetry
447449
retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions)
448450
result, err := retry(ctx, func() (*Response, error) {
449-
return stepModel.Generate(ctx, Call{
451+
return gen(ctx, stepModel, Call{
450452
Prompt: stepInputMessages,
451453
MaxOutputTokens: opts.MaxOutputTokens,
452454
Temperature: opts.Temperature,
@@ -485,7 +487,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
485487

486488
toolResults, err := a.executeTools(ctx, stepTools, stepExecProviderTools, stepToolCalls, nil)
487489

488-
// Build step content with validated tool calls and tool results. // Provider-executed tool calls are kept as-is.
490+
// Build step content with validated tool calls and tool results. Provider-executed tool calls are kept as-is.
489491
stepContent := []Content{}
490492
toolCallIndex := 0
491493
for _, content := range result.Content {
@@ -528,8 +530,11 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
528530
}
529531
}
530532

531-
totalUsage := Usage{}
533+
return steps, nil
534+
}
532535

536+
func toAgentResult(steps []StepResult) *AgentResult {
537+
totalUsage := Usage{}
533538
for _, step := range steps {
534539
usage := step.Usage
535540
totalUsage.InputTokens += usage.InputTokens
@@ -540,12 +545,89 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
540545
totalUsage.TotalTokens += usage.TotalTokens
541546
}
542547

543-
agentResult := &AgentResult{
548+
return &AgentResult{
544549
Steps: steps,
545550
Response: steps[len(steps)-1].Response,
546551
TotalUsage: totalUsage,
547552
}
548-
return agentResult, nil
553+
}
554+
555+
// Generate implements Agent.
//
// It applies prepareCall to opts, builds the initial prompt with createPrompt,
// and runs executeLoop with a generator that forwards each call to the step
// model's Generate method. The returned steps are folded into an AgentResult
// by toAgentResult. An error from createPrompt or executeLoop is returned
// unchanged together with a nil result.
556+
func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, error) {
557+
opts = a.prepareCall(opts)
558+
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
559+
if err != nil {
560+
return nil, err
561+
}
562+
steps, err := a.executeLoop(
563+
ctx,
564+
initialPrompt,
// Forward the call unchanged to the step model's Generate method.
565+
func(ctx context.Context, stepModel LanguageModel, call Call) (*Response, error) {
566+
return stepModel.Generate(ctx, call)
567+
},
568+
opts,
569+
)
570+
if err != nil {
571+
return nil, err
572+
}
573+
574+
// NOTE(review): toAgentResult reads steps[len(steps)-1], so this relies on
// executeLoop returning at least one step when err is nil — confirm.
return toAgentResult(steps), nil
575+
}
576+
577+
// GenerateObject implements Agent.
//
// It follows the same steps as Generate (prepareCall, createPrompt,
// executeLoop, toAgentResult) but the generator passed to executeLoop calls
// the model's GenerateObject with schema s and adapts the ObjectResponse into
// a Response so the shared loop can process tool calls as usual.
//
// NOTE(review): only res.RawText is carried into the Response; the parsed
// object itself is not copied, so callers presumably decode the final text
// themselves — confirm this is intended.
func (a *agent) GenerateObject(ctx context.Context, s schema.Schema, opts AgentCall) (*AgentResult, error) {
578+
opts = a.prepareCall(opts)
579+
initialPrompt, err := a.createPrompt(a.settings.systemPrompt, opts.Prompt, opts.Messages, opts.Files...)
580+
if err != nil {
581+
return nil, err
582+
}
583+
584+
steps, err := a.executeLoop(
585+
ctx,
586+
initialPrompt,
// Translate the loop's Call into an ObjectCall field by field, adding the schema.
587+
func(ctx context.Context, model LanguageModel, call Call) (*Response, error) {
588+
res, err := model.GenerateObject(ctx, ObjectCall{
589+
Prompt: call.Prompt,
590+
Schema: s,
591+
MaxOutputTokens: call.MaxOutputTokens,
592+
Temperature: call.Temperature,
593+
TopP: call.TopP,
594+
TopK: call.TopK,
595+
PresencePenalty: call.PresencePenalty,
596+
FrequencyPenalty: call.FrequencyPenalty,
597+
UserAgent: call.UserAgent,
598+
ProviderOptions: call.ProviderOptions,
// No repair function is supplied for agent-driven object generation.
599+
RepairText: nil,
600+
Tools: call.Tools,
601+
ToolChoice: call.ToolChoice,
602+
})
603+
if err != nil {
604+
return nil, err
605+
}
606+
607+
// Rebuild response content: tool calls first, then the raw text if any.
var content ResponseContent
608+
for _, toolCall := range res.ToolCalls {
609+
content = append(content, toolCall)
610+
}
611+
612+
// An empty RawText adds no text part to the content.
if res.RawText != "" {
613+
content = append(content, TextContent{Text: res.RawText})
614+
}
615+
616+
return &Response{
617+
Content: content,
618+
FinishReason: res.FinishReason,
619+
Usage: res.Usage,
620+
Warnings: res.Warnings,
621+
ProviderMetadata: res.ProviderMetadata,
622+
}, nil
623+
},
624+
opts,
625+
)
626+
if err != nil {
627+
return nil, err
628+
}
629+
630+
// NOTE(review): toAgentResult reads steps[len(steps)-1], so this relies on
// executeLoop returning at least one step when err is nil — confirm.
return toAgentResult(steps), nil
549631
}
550632

551633
func isStopConditionMet(conditions []StopCondition, steps []StepResult) bool {

object.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ type ObjectCall struct {
4747
ProviderOptions ProviderOptions
4848

4949
RepairText schema.ObjectRepairFunc
50+
51+
Tools []Tool `json:"tools"`
52+
ToolChoice *ToolChoice `json:"tool_choice"`
5053
}
5154

5255
// ObjectResponse represents the response from a structured object generation.
@@ -57,6 +60,7 @@ type ObjectResponse struct {
5760
FinishReason FinishReason
5861
Warnings []CallWarning
5962
ProviderMetadata ProviderMetadata
63+
ToolCalls []ToolCallContent
6064
}
6165

6266
// ObjectStreamPartType indicates the type of stream part.
@@ -99,6 +103,7 @@ type ObjectResult[T any] struct {
99103
FinishReason FinishReason
100104
Warnings []CallWarning
101105
ProviderMetadata ProviderMetadata
106+
ToolCalls []ToolCallContent
102107
}
103108

104109
// StreamObjectResult provides typed access to a streaming object generation result.

object/object.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func Generate[T any](
5151
FinishReason: resp.FinishReason,
5252
Warnings: resp.Warnings,
5353
ProviderMetadata: resp.ProviderMetadata,
54+
ToolCalls: resp.ToolCalls,
5455
}, nil
5556
}
5657

providers/openai/responses_language_model.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,6 +1325,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13251325
PresencePenalty: call.PresencePenalty,
13261326
FrequencyPenalty: call.FrequencyPenalty,
13271327
ProviderOptions: call.ProviderOptions,
1328+
Tools: call.Tools,
1329+
ToolChoice: call.ToolChoice,
13281330
}
13291331

13301332
params, warnings, err := o.prepareParams(fantasyCall)
@@ -1350,8 +1352,8 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13501352
}
13511353
}
13521354

1353-
// Extract JSON text from response
13541355
var jsonText string
1356+
var toolCalls []fantasy.ToolCallContent
13551357
for _, outputItem := range response.Output {
13561358
if outputItem.Type == "message" {
13571359
for _, contentPart := range outputItem.Content {
@@ -1361,15 +1363,20 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13611363
}
13621364
}
13631365
}
1366+
if outputItem.Type == "function_call" {
1367+
toolCalls = append(toolCalls, fantasy.ToolCallContent{
1368+
ProviderExecuted: false,
1369+
ToolCallID: outputItem.CallID,
1370+
ToolName: outputItem.Name,
1371+
Input: outputItem.Arguments.OfString,
1372+
})
1373+
}
13641374
}
13651375

1366-
if jsonText == "" {
1367-
usage := fantasy.Usage{
1368-
InputTokens: response.Usage.InputTokens,
1369-
OutputTokens: response.Usage.OutputTokens,
1370-
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
1371-
}
1372-
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
1376+
usage := responsesUsage(*response)
1377+
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
1378+
1379+
if jsonText == "" && len(toolCalls) == 0 {
13731380
return nil, &fantasy.NoObjectGeneratedError{
13741381
RawText: "",
13751382
ParseError: fmt.Errorf("no text content in response"),
@@ -1378,6 +1385,14 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13781385
}
13791386
}
13801387

1388+
if jsonText == "" && len(toolCalls) > 0 {
1389+
return &fantasy.ObjectResponse{
1390+
Usage: usage,
1391+
FinishReason: finishReason,
1392+
ToolCalls: toolCalls,
1393+
}, nil
1394+
}
1395+
13811396
// Parse and validate
13821397
var obj any
13831398
if call.RepairText != nil {
@@ -1386,9 +1401,6 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
13861401
obj, err = schema.ParseAndValidate(jsonText, call.Schema)
13871402
}
13881403

1389-
usage := responsesUsage(*response)
1390-
finishReason := mapResponsesFinishReason(response.IncompleteDetails.Reason, false)
1391-
13921404
if err != nil {
13931405
// Add usage info to error
13941406
if nogErr, ok := err.(*fantasy.NoObjectGeneratedError); ok {
@@ -1405,6 +1417,7 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context,
14051417
FinishReason: finishReason,
14061418
Warnings: warnings,
14071419
ProviderMetadata: responsesProviderMetadata(response.ID),
1420+
ToolCalls: toolCalls,
14081421
}, nil
14091422
}
14101423

0 commit comments

Comments
 (0)