Skip to content

Commit 03613bc

Browse files
fix(adk): set streaming meta for agentic tool chunks
Change-Id: I4d42e16cd420690625342d5097df4dbcae07cb8b
1 parent 804b9d4 commit 03613bc

5 files changed

Lines changed: 212 additions & 11 deletions

File tree

adk/wrappers.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,18 @@ func functionToolResultAgenticMessage(callID, name string, content []*schema.Fun
594594
}
595595
}
596596

597+
func markAgenticMessageStreamingMeta(msg *schema.AgenticMessage, index int) {
598+
if msg == nil {
599+
return
600+
}
601+
for _, block := range msg.ContentBlocks {
602+
if block == nil {
603+
continue
604+
}
605+
block.StreamingMeta = &schema.StreamingMeta{Index: index}
606+
}
607+
}
608+
597609
func toolSearchResultAgenticMessage(callID, name string, tr *schema.ToolResult) (*schema.AgenticMessage, bool) {
598610
if tr == nil || len(tr.Parts) != 1 {
599611
return nil, false
@@ -745,6 +757,7 @@ func typedToolStreamEvent[M MessageType](callID, toolName, toolMsgID string, str
745757
first := true
746758
cvt := func(in string) (*schema.AgenticMessage, error) {
747759
msg := functionToolResultAgenticMessage(callID, toolName, textToFunctionToolResultBlocks(in))
760+
markAgenticMessageStreamingMeta(msg, 0)
748761
if first {
749762
first = false
750763
msg.Extra = internal.SetMessageID(msg.Extra, toolMsgID)
@@ -813,6 +826,7 @@ func typedToolEnhancedStreamEvent[M MessageType](callID, toolName, toolMsgID str
813826
first := true
814827
cvt := func(in *schema.ToolResult) (*schema.AgenticMessage, error) {
815828
msg := toolResultAgenticMessage(callID, toolName, in)
829+
markAgenticMessageStreamingMeta(msg, 0)
816830
if first {
817831
first = false
818832
msg.Extra = internal.SetMessageID(msg.Extra, toolMsgID)

adk/wrappers_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,6 +1974,73 @@ func TestAgenticEventSenderToolHandler(t *testing.T) {
19741974
})
19751975
}
19761976

1977+
func TestTypedToolStreamEventAgenticMessageSetsStreamingMeta(t *testing.T) {
1978+
event := typedToolStreamEvent[*schema.AgenticMessage](
1979+
"call_1",
1980+
"execute",
1981+
"msg_1",
1982+
schema.StreamReaderFromArray([]string{"first\n", "second\n"}),
1983+
)
1984+
require.NotNil(t, event)
1985+
require.NotNil(t, event.Output)
1986+
require.NotNil(t, event.Output.MessageOutput)
1987+
require.True(t, event.Output.MessageOutput.IsStreaming)
1988+
require.NotNil(t, event.Output.MessageOutput.MessageStream)
1989+
1990+
first, err := event.Output.MessageOutput.MessageStream.Recv()
1991+
require.NoError(t, err)
1992+
require.Len(t, first.ContentBlocks, 1)
1993+
assert.Equal(t, &schema.StreamingMeta{Index: 0}, first.ContentBlocks[0].StreamingMeta)
1994+
1995+
second, err := event.Output.MessageOutput.MessageStream.Recv()
1996+
require.NoError(t, err)
1997+
require.Len(t, second.ContentBlocks, 1)
1998+
assert.Equal(t, &schema.StreamingMeta{Index: 0}, second.ContentBlocks[0].StreamingMeta)
1999+
2000+
result, err := schema.ConcatAgenticMessages([]*schema.AgenticMessage{first, second})
2001+
require.NoError(t, err)
2002+
require.Len(t, result.ContentBlocks, 1)
2003+
assert.Nil(t, result.ContentBlocks[0].StreamingMeta)
2004+
require.NotNil(t, result.ContentBlocks[0].FunctionToolResult)
2005+
require.Len(t, result.ContentBlocks[0].FunctionToolResult.Content, 1)
2006+
assert.Equal(t, "first\nsecond\n", result.ContentBlocks[0].FunctionToolResult.Content[0].Text.Text)
2007+
}
2008+
2009+
func TestTypedToolEnhancedStreamEventAgenticMessageSetsStreamingMeta(t *testing.T) {
2010+
event := typedToolEnhancedStreamEvent[*schema.AgenticMessage](
2011+
"call_1",
2012+
"execute",
2013+
"msg_1",
2014+
schema.StreamReaderFromArray([]*schema.ToolResult{
2015+
{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "first\n"}}},
2016+
{Parts: []schema.ToolOutputPart{{Type: schema.ToolPartTypeText, Text: "second\n"}}},
2017+
}),
2018+
)
2019+
require.NotNil(t, event)
2020+
require.NotNil(t, event.Output)
2021+
require.NotNil(t, event.Output.MessageOutput)
2022+
require.True(t, event.Output.MessageOutput.IsStreaming)
2023+
require.NotNil(t, event.Output.MessageOutput.MessageStream)
2024+
2025+
first, err := event.Output.MessageOutput.MessageStream.Recv()
2026+
require.NoError(t, err)
2027+
require.Len(t, first.ContentBlocks, 1)
2028+
assert.Equal(t, &schema.StreamingMeta{Index: 0}, first.ContentBlocks[0].StreamingMeta)
2029+
2030+
second, err := event.Output.MessageOutput.MessageStream.Recv()
2031+
require.NoError(t, err)
2032+
require.Len(t, second.ContentBlocks, 1)
2033+
assert.Equal(t, &schema.StreamingMeta{Index: 0}, second.ContentBlocks[0].StreamingMeta)
2034+
2035+
result, err := schema.ConcatAgenticMessages([]*schema.AgenticMessage{first, second})
2036+
require.NoError(t, err)
2037+
require.Len(t, result.ContentBlocks, 1)
2038+
assert.Nil(t, result.ContentBlocks[0].StreamingMeta)
2039+
require.NotNil(t, result.ContentBlocks[0].FunctionToolResult)
2040+
require.Len(t, result.ContentBlocks[0].FunctionToolResult.Content, 1)
2041+
assert.Equal(t, "first\nsecond\n", result.ContentBlocks[0].FunctionToolResult.Content[0].Text.Text)
2042+
}
2043+
19772044
// multimodalEnhancedInvokableTestTool returns a pre-built multimodal ToolResult.
19782045
type multimodalEnhancedInvokableTestTool struct {
19792046
name string
@@ -2091,6 +2158,7 @@ func TestTypedToolEnhancedEventAgenticToolSearchResult(t *testing.T) {
20912158
require.Len(t, msg.ContentBlocks, 1)
20922159
block := msg.ContentBlocks[0]
20932160
assert.Equal(t, schema.ContentBlockTypeToolSearchResult, block.Type)
2161+
assert.Equal(t, &schema.StreamingMeta{Index: 0}, block.StreamingMeta)
20942162
require.NotNil(t, block.ToolSearchFunctionToolResult)
20952163
assert.Equal(t, "call_2", block.ToolSearchFunctionToolResult.CallID)
20962164
assert.Equal(t, "tool_search", block.ToolSearchFunctionToolResult.Name)

compose/agentic_tools_node_test.go

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
package compose
1818

1919
import (
20+
"context"
2021
"io"
2122
"testing"
2223

2324
"github.com/bytedance/sonic"
2425
"github.com/stretchr/testify/assert"
2526
"github.com/stretchr/testify/require"
2627

28+
"github.com/cloudwego/eino/components/tool"
2729
"github.com/cloudwego/eino/schema"
2830
)
2931

@@ -343,6 +345,57 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) {
343345
})
344346
}
345347

348+
func TestAgenticToolsNodeStreamSetsStreamingMeta(t *testing.T) {
349+
ctx := context.Background()
350+
node, err := NewAgenticToolsNode(ctx, &ToolsNodeConfig{
351+
Tools: []tool.BaseTool{&mockTool{}},
352+
})
353+
require.NoError(t, err)
354+
355+
stream, err := node.Stream(ctx, &schema.AgenticMessage{
356+
ContentBlocks: []*schema.ContentBlock{
357+
{
358+
Type: schema.ContentBlockTypeFunctionToolCall,
359+
FunctionToolCall: &schema.FunctionToolCall{
360+
CallID: "call_1",
361+
Name: "mock_tool",
362+
Arguments: `{"name":"jack"}`,
363+
},
364+
},
365+
},
366+
})
367+
require.NoError(t, err)
368+
defer stream.Close()
369+
370+
var chunks [][]*schema.AgenticMessage
371+
for {
372+
chunk, err := stream.Recv()
373+
if err == io.EOF {
374+
break
375+
}
376+
require.NoError(t, err)
377+
require.Len(t, chunk, 1)
378+
require.Len(t, chunk[0].ContentBlocks, 1)
379+
block := chunk[0].ContentBlocks[0]
380+
assert.Equal(t, schema.ContentBlockTypeFunctionToolResult, block.Type)
381+
assert.Equal(t, &schema.StreamingMeta{Index: 0}, block.StreamingMeta)
382+
chunks = append(chunks, chunk)
383+
}
384+
require.NotEmpty(t, chunks)
385+
386+
result, err := schema.ConcatAgenticMessagesArray(chunks)
387+
require.NoError(t, err)
388+
require.Len(t, result, 1)
389+
require.Len(t, result[0].ContentBlocks, 1)
390+
block := result[0].ContentBlocks[0]
391+
assert.Nil(t, block.StreamingMeta)
392+
require.NotNil(t, block.FunctionToolResult)
393+
assert.Equal(t, "call_1", block.FunctionToolResult.CallID)
394+
assert.Equal(t, "mock_tool", block.FunctionToolResult.Name)
395+
require.Len(t, block.FunctionToolResult.Content, 1)
396+
assert.JSONEq(t, `{"echo":"jack: 0"}`, block.FunctionToolResult.Content[0].Text.Text)
397+
}
398+
346399
func testStreamToolMessageTextOnly(t *testing.T) {
347400
input := schema.StreamReaderFromArray([][]*schema.Message{
348401
{
@@ -435,8 +488,7 @@ func testStreamToolMessageTextOnly(t *testing.T) {
435488
CallID: "2",
436489
Name: "name2",
437490
Content: []*schema.FunctionToolResultContentBlock{
438-
{Type: schema.FunctionToolResultContentBlockTypeText, Text: &schema.UserInputText{Text: "content2-1"}},
439-
{Type: schema.FunctionToolResultContentBlockTypeText, Text: &schema.UserInputText{Text: "content2-2"}},
491+
{Type: schema.FunctionToolResultContentBlockTypeText, Text: &schema.UserInputText{Text: "content2-1content2-2"}},
440492
},
441493
},
442494
},
@@ -451,8 +503,7 @@ func testStreamToolMessageTextOnly(t *testing.T) {
451503
CallID: "3",
452504
Name: "name3",
453505
Content: []*schema.FunctionToolResultContentBlock{
454-
{Type: schema.FunctionToolResultContentBlockTypeText, Text: &schema.UserInputText{Text: "content3-1"}},
455-
{Type: schema.FunctionToolResultContentBlockTypeText, Text: &schema.UserInputText{Text: "content3-2"}},
506+
{Type: schema.FunctionToolResultContentBlockTypeText, Text: &schema.UserInputText{Text: "content3-1content3-2"}},
456507
},
457508
},
458509
},

schema/agentic_message.go

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,12 +1623,66 @@ func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResu
16231623
return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name)
16241624
}
16251625

1626-
for _, b := range r.Content {
1627-
if b == nil {
1628-
continue
1626+
var err error
1627+
ret.Content, err = concatFunctionToolResultContent(ret.Content, r.Content)
1628+
if err != nil {
1629+
return nil, err
1630+
}
1631+
}
1632+
1633+
return ret, nil
1634+
}
1635+
1636+
func concatFunctionToolResultContent(
1637+
left, right []*FunctionToolResultContentBlock,
1638+
) ([]*FunctionToolResultContentBlock, error) {
1639+
ret := append([]*FunctionToolResultContentBlock(nil), left...)
1640+
for _, block := range right {
1641+
if block == nil {
1642+
continue
1643+
}
1644+
if len(ret) > 0 && canConcatFunctionToolResultTextBlocks(ret[len(ret)-1], block) {
1645+
merged, err := concatFunctionToolResultTextBlocks(ret[len(ret)-1], block)
1646+
if err != nil {
1647+
return nil, err
16291648
}
1630-
ret.Content = append(ret.Content, b)
1649+
ret[len(ret)-1] = merged
1650+
continue
1651+
}
1652+
ret = append(ret, block)
1653+
}
1654+
1655+
return ret, nil
1656+
}
1657+
1658+
func canConcatFunctionToolResultTextBlocks(a, b *FunctionToolResultContentBlock) bool {
1659+
return a != nil && b != nil &&
1660+
a.Type == FunctionToolResultContentBlockTypeText &&
1661+
b.Type == FunctionToolResultContentBlockTypeText &&
1662+
a.Text != nil && b.Text != nil
1663+
}
1664+
1665+
func concatFunctionToolResultTextBlocks(
1666+
a, b *FunctionToolResultContentBlock,
1667+
) (*FunctionToolResultContentBlock, error) {
1668+
ret := &FunctionToolResultContentBlock{
1669+
Type: FunctionToolResultContentBlockTypeText,
1670+
Text: &UserInputText{Text: a.Text.Text + b.Text.Text},
1671+
}
1672+
1673+
var extras []map[string]any
1674+
if len(a.Extra) > 0 {
1675+
extras = append(extras, a.Extra)
1676+
}
1677+
if len(b.Extra) > 0 {
1678+
extras = append(extras, b.Extra)
1679+
}
1680+
if len(extras) > 0 {
1681+
extra, err := concatExtra(extras)
1682+
if err != nil {
1683+
return nil, fmt.Errorf("failed to concat function tool result content extras: %w", err)
16311684
}
1685+
ret.Extra = extra
16321686
}
16331687

16341688
return ret, nil

schema/agentic_message_test.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,8 @@ func TestConcatAgenticMessages(t *testing.T) {
526526
assert.Len(t, result.ContentBlocks, 1)
527527
assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolResult.CallID)
528528
assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolResult.Name)
529-
assert.Equal(t, 2, len(result.ContentBlocks[0].FunctionToolResult.Content))
530-
assert.Equal(t, `{"temp`, result.ContentBlocks[0].FunctionToolResult.Content[0].Text.Text)
531-
assert.Equal(t, `":72}`, result.ContentBlocks[0].FunctionToolResult.Content[1].Text.Text)
529+
assert.Equal(t, 1, len(result.ContentBlocks[0].FunctionToolResult.Content))
530+
assert.Equal(t, `{"temp":72}`, result.ContentBlocks[0].FunctionToolResult.Content[0].Text.Text)
532531
})
533532

534533
t.Run("concat server tool call", func(t *testing.T) {
@@ -1726,4 +1725,19 @@ func TestConcatFunctionToolResults(t *testing.T) {
17261725
assert.Equal(t, "hello", got.Content[0].Text.Text)
17271726
assert.Equal(t, "http://img.png", got.Content[1].Image.URL)
17281727
})
1728+
1729+
t.Run("text chunks", func(t *testing.T) {
1730+
results := []*FunctionToolResult{
1731+
{CallID: "c1", Name: "tool1", Content: []*FunctionToolResultContentBlock{
1732+
{Type: FunctionToolResultContentBlockTypeText, Text: &UserInputText{Text: "hello "}},
1733+
}},
1734+
{CallID: "c1", Name: "tool1", Content: []*FunctionToolResultContentBlock{
1735+
{Type: FunctionToolResultContentBlockTypeText, Text: &UserInputText{Text: "world"}},
1736+
}},
1737+
}
1738+
got, err := concatFunctionToolResults(results)
1739+
require.NoError(t, err)
1740+
require.Len(t, got.Content, 1)
1741+
assert.Equal(t, "hello world", got.Content[0].Text.Text)
1742+
})
17291743
}

0 commit comments

Comments
 (0)