Skip to content

Commit 97682b9

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 627bbc5 + 31802ee commit 97682b9

14 files changed

Lines changed: 1053 additions & 252 deletions

agent.go

Lines changed: 10 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -636,94 +636,11 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
636636
results := make([]ToolResultContent, 0, len(toolCalls))
637637

638638
for _, toolCall := range toolCalls {
639-
// Skip invalid tool calls - create error result
640-
if toolCall.Invalid {
641-
result := ToolResultContent{
642-
ToolCallID: toolCall.ToolCallID,
643-
ToolName: toolCall.ToolName,
644-
Result: ToolResultOutputContentError{
645-
Error: toolCall.ValidationError,
646-
},
647-
ProviderExecuted: false,
648-
}
649-
results = append(results, result)
650-
if toolResultCallback != nil {
651-
if err := toolResultCallback(result); err != nil {
652-
return nil, err
653-
}
654-
}
655-
continue
656-
}
657-
658-
tool, exists := toolMap[toolCall.ToolName]
659-
if !exists {
660-
result := ToolResultContent{
661-
ToolCallID: toolCall.ToolCallID,
662-
ToolName: toolCall.ToolName,
663-
Result: ToolResultOutputContentError{
664-
Error: errors.New("Error: Tool not found: " + toolCall.ToolName),
665-
},
666-
ProviderExecuted: false,
667-
}
668-
results = append(results, result)
669-
if toolResultCallback != nil {
670-
if err := toolResultCallback(result); err != nil {
671-
return nil, err
672-
}
673-
}
674-
continue
675-
}
676-
677-
// Execute the tool
678-
toolResult, err := tool.Run(ctx, ToolCall{
679-
ID: toolCall.ToolCallID,
680-
Name: toolCall.ToolName,
681-
Input: toolCall.Input,
682-
})
683-
if err != nil {
684-
result := ToolResultContent{
685-
ToolCallID: toolCall.ToolCallID,
686-
ToolName: toolCall.ToolName,
687-
Result: ToolResultOutputContentError{
688-
Error: err,
689-
},
690-
ClientMetadata: toolResult.Metadata,
691-
ProviderExecuted: false,
692-
}
693-
if toolResultCallback != nil {
694-
if cbErr := toolResultCallback(result); cbErr != nil {
695-
return nil, cbErr
696-
}
697-
}
698-
return nil, err
699-
}
700-
701-
var result ToolResultContent
702-
if toolResult.IsError {
703-
result = ToolResultContent{
704-
ToolCallID: toolCall.ToolCallID,
705-
ToolName: toolCall.ToolName,
706-
Result: ToolResultOutputContentError{
707-
Error: errors.New(toolResult.Content),
708-
},
709-
ClientMetadata: toolResult.Metadata,
710-
ProviderExecuted: false,
711-
}
712-
} else {
713-
result = ToolResultContent{
714-
ToolCallID: toolCall.ToolCallID,
715-
ToolName: toolCall.ToolName,
716-
Result: ToolResultOutputContentText{
717-
Text: toolResult.Content,
718-
},
719-
ClientMetadata: toolResult.Metadata,
720-
ProviderExecuted: false,
721-
}
722-
}
639+
result, isCriticalError := a.executeSingleTool(ctx, toolMap, toolCall, toolResultCallback)
723640
results = append(results, result)
724-
if toolResultCallback != nil {
725-
if err := toolResultCallback(result); err != nil {
726-
return nil, err
641+
if isCriticalError {
642+
if errorResult, ok := result.Result.(ToolResultOutputContentError); ok && errorResult.Error != nil {
643+
return nil, errorResult.Error
727644
}
728645
}
729646
}
@@ -775,7 +692,6 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
775692
if toolResultCallback != nil {
776693
_ = toolResultCallback(result)
777694
}
778-
// This is a critical error - tool.Run() failed
779695
return result, true
780696
}
781697

@@ -784,6 +700,12 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
784700
result.Result = ToolResultOutputContentError{
785701
Error: errors.New(toolResult.Content),
786702
}
703+
} else if toolResult.Type == "image" || toolResult.Type == "media" {
704+
result.Result = ToolResultOutputContentMedia{
705+
Data: string(toolResult.Data),
706+
MediaType: toolResult.MediaType,
707+
Text: toolResult.Content,
708+
}
787709
} else {
788710
result.Result = ToolResultOutputContentText{
789711
Text: toolResult.Content,
@@ -792,7 +714,6 @@ func (a *agent) executeSingleTool(ctx context.Context, toolMap map[string]AgentT
792714
if toolResultCallback != nil {
793715
_ = toolResultCallback(result)
794716
}
795-
// Not a critical error - tool ran successfully (even if it reported an error state)
796717
return result, false
797718
}
798719

agent_test.go

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,3 +1535,236 @@ func TestToolCallRepair(t *testing.T) {
15351535
require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
15361536
})
15371537
}
1538+
1539+
// Test media and image tool responses
1540+
func TestAgent_MediaToolResponses(t *testing.T) {
1541+
t.Parallel()
1542+
1543+
imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1544+
audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1545+
1546+
t.Run("Image tool response", func(t *testing.T) {
1547+
t.Parallel()
1548+
1549+
imageTool := &mockTool{
1550+
name: "generate_image",
1551+
description: "Generates an image",
1552+
executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1553+
return NewImageResponse(imageData, "image/png"), nil
1554+
},
1555+
}
1556+
1557+
model := &mockLanguageModel{
1558+
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1559+
if len(call.Prompt) == 1 {
1560+
// First call - request image tool
1561+
return &Response{
1562+
Content: []Content{
1563+
ToolCallContent{
1564+
ToolCallID: "img-1",
1565+
ToolName: "generate_image",
1566+
Input: `{}`,
1567+
},
1568+
},
1569+
Usage: Usage{TotalTokens: 10},
1570+
FinishReason: FinishReasonToolCalls,
1571+
}, nil
1572+
}
1573+
// Second call - after tool execution
1574+
return &Response{
1575+
Content: []Content{TextContent{Text: "Image generated"}},
1576+
Usage: Usage{TotalTokens: 20},
1577+
FinishReason: FinishReasonStop,
1578+
}, nil
1579+
},
1580+
}
1581+
1582+
agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1583+
1584+
result, err := agent.Generate(context.Background(), AgentCall{
1585+
Prompt: "Generate an image",
1586+
})
1587+
1588+
require.NoError(t, err)
1589+
require.NotNil(t, result)
1590+
require.Len(t, result.Steps, 2) // Tool call step + final response
1591+
1592+
// Check tool results in first step
1593+
toolResults := result.Steps[0].Content.ToolResults()
1594+
require.Len(t, toolResults, 1)
1595+
1596+
mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1597+
require.True(t, ok, "Expected media result")
1598+
require.Equal(t, string(imageData), mediaResult.Data)
1599+
require.Equal(t, "image/png", mediaResult.MediaType)
1600+
})
1601+
1602+
t.Run("Media tool response (audio)", func(t *testing.T) {
1603+
t.Parallel()
1604+
1605+
audioTool := &mockTool{
1606+
name: "generate_audio",
1607+
description: "Generates audio",
1608+
executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1609+
return NewMediaResponse(audioData, "audio/wav"), nil
1610+
},
1611+
}
1612+
1613+
model := &mockLanguageModel{
1614+
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1615+
if len(call.Prompt) == 1 {
1616+
return &Response{
1617+
Content: []Content{
1618+
ToolCallContent{
1619+
ToolCallID: "audio-1",
1620+
ToolName: "generate_audio",
1621+
Input: `{}`,
1622+
},
1623+
},
1624+
Usage: Usage{TotalTokens: 10},
1625+
FinishReason: FinishReasonToolCalls,
1626+
}, nil
1627+
}
1628+
return &Response{
1629+
Content: []Content{TextContent{Text: "Audio generated"}},
1630+
Usage: Usage{TotalTokens: 20},
1631+
FinishReason: FinishReasonStop,
1632+
}, nil
1633+
},
1634+
}
1635+
1636+
agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1637+
1638+
result, err := agent.Generate(context.Background(), AgentCall{
1639+
Prompt: "Generate audio",
1640+
})
1641+
1642+
require.NoError(t, err)
1643+
require.NotNil(t, result)
1644+
1645+
toolResults := result.Steps[0].Content.ToolResults()
1646+
require.Len(t, toolResults, 1)
1647+
1648+
mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1649+
require.True(t, ok, "Expected media result")
1650+
require.Equal(t, string(audioData), mediaResult.Data)
1651+
require.Equal(t, "audio/wav", mediaResult.MediaType)
1652+
})
1653+
1654+
t.Run("Media response with text", func(t *testing.T) {
1655+
t.Parallel()
1656+
1657+
imageTool := &mockTool{
1658+
name: "screenshot",
1659+
description: "Takes a screenshot",
1660+
executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1661+
resp := NewImageResponse(imageData, "image/png")
1662+
resp.Content = "Screenshot captured successfully"
1663+
return resp, nil
1664+
},
1665+
}
1666+
1667+
model := &mockLanguageModel{
1668+
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1669+
if len(call.Prompt) == 1 {
1670+
return &Response{
1671+
Content: []Content{
1672+
ToolCallContent{
1673+
ToolCallID: "screen-1",
1674+
ToolName: "screenshot",
1675+
Input: `{}`,
1676+
},
1677+
},
1678+
Usage: Usage{TotalTokens: 10},
1679+
FinishReason: FinishReasonToolCalls,
1680+
}, nil
1681+
}
1682+
return &Response{
1683+
Content: []Content{TextContent{Text: "Done"}},
1684+
Usage: Usage{TotalTokens: 20},
1685+
FinishReason: FinishReasonStop,
1686+
}, nil
1687+
},
1688+
}
1689+
1690+
agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1691+
1692+
result, err := agent.Generate(context.Background(), AgentCall{
1693+
Prompt: "Take a screenshot",
1694+
})
1695+
1696+
require.NoError(t, err)
1697+
require.NotNil(t, result)
1698+
1699+
toolResults := result.Steps[0].Content.ToolResults()
1700+
require.Len(t, toolResults, 1)
1701+
1702+
mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1703+
require.True(t, ok, "Expected media result")
1704+
require.Equal(t, string(imageData), mediaResult.Data)
1705+
require.Equal(t, "image/png", mediaResult.MediaType)
1706+
require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1707+
})
1708+
1709+
t.Run("Media response preserves metadata", func(t *testing.T) {
1710+
t.Parallel()
1711+
1712+
type ImageMetadata struct {
1713+
Width int `json:"width"`
1714+
Height int `json:"height"`
1715+
}
1716+
1717+
imageTool := &mockTool{
1718+
name: "generate_image",
1719+
description: "Generates an image",
1720+
executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1721+
resp := NewImageResponse(imageData, "image/png")
1722+
return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1723+
},
1724+
}
1725+
1726+
model := &mockLanguageModel{
1727+
generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1728+
if len(call.Prompt) == 1 {
1729+
return &Response{
1730+
Content: []Content{
1731+
ToolCallContent{
1732+
ToolCallID: "img-1",
1733+
ToolName: "generate_image",
1734+
Input: `{}`,
1735+
},
1736+
},
1737+
Usage: Usage{TotalTokens: 10},
1738+
FinishReason: FinishReasonToolCalls,
1739+
}, nil
1740+
}
1741+
return &Response{
1742+
Content: []Content{TextContent{Text: "Done"}},
1743+
Usage: Usage{TotalTokens: 20},
1744+
FinishReason: FinishReasonStop,
1745+
}, nil
1746+
},
1747+
}
1748+
1749+
agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1750+
1751+
result, err := agent.Generate(context.Background(), AgentCall{
1752+
Prompt: "Generate image",
1753+
})
1754+
1755+
require.NoError(t, err)
1756+
require.NotNil(t, result)
1757+
1758+
toolResults := result.Steps[0].Content.ToolResults()
1759+
require.Len(t, toolResults, 1)
1760+
1761+
// Check metadata was preserved
1762+
require.NotEmpty(t, toolResults[0].ClientMetadata)
1763+
1764+
var metadata ImageMetadata
1765+
err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1766+
require.NoError(t, err)
1767+
require.Equal(t, 800, metadata.Width)
1768+
require.Equal(t, 600, metadata.Height)
1769+
})
1770+
}

content.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,9 @@ func (t ToolResultOutputContentError) GetType() ToolResultContentType {
306306

307307
// ToolResultOutputContentMedia represents media output content of a tool result.
308308
type ToolResultOutputContentMedia struct {
309-
Data string `json:"data"` // for media type (base64)
310-
MediaType string `json:"media_type"` // for media type
309+
Data string `json:"data"` // for media type (base64)
310+
MediaType string `json:"media_type"` // for media type
311+
Text string `json:"text,omitempty"` // optional text content accompanying the media
311312
}
312313

313314
// GetType returns the type of the tool result output content media.

0 commit comments

Comments
 (0)