@@ -38,14 +38,15 @@ import (
3838 "github.com/coze-dev/coze-studio/backend/pkg/logs"
3939)
4040
41- func newReplyCallback (_ context.Context , executeID string ) (clb callbacks.Handler ,
41+ func newReplyCallback (_ context.Context , executeID string , returnDirectlyTools map [ string ] struct {} ) (clb callbacks.Handler ,
4242 sr * schema.StreamReader [* entity.AgentEvent ], sw * schema.StreamWriter [* entity.AgentEvent ],
4343) {
4444 sr , sw = schema.Pipe [* entity.AgentEvent ](10 )
4545
4646 rcc := & replyChunkCallback {
47- sw : sw ,
48- executeID : executeID ,
47+ sw : sw ,
48+ executeID : executeID ,
49+ returnDirectlyTools : returnDirectlyTools ,
4950 }
5051
5152 clb = callbacks .NewHandlerBuilder ().
@@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
5960}
6061
6162type replyChunkCallback struct {
62- sw * schema.StreamWriter [* entity.AgentEvent ]
63- executeID string
63+ sw * schema.StreamWriter [* entity.AgentEvent ]
64+ executeID string
65+ returnDirectlyTools map [string ]struct {}
6466}
6567
6668func (r * replyChunkCallback ) OnError (ctx context.Context , info * callbacks.RunInfo , err error ) context.Context {
@@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
201203 }, nil )
202204 return ctx
203205 case compose .ComponentOfToolsNode :
204- toolsMessage , err := concatToolsNodeOutput (ctx , output )
206+ toolsMessage , err := r . concatToolsNodeOutput (ctx , output )
205207 if err != nil {
206208 r .sw .Send (nil , err )
207209 return ctx
@@ -270,37 +272,70 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
270272 return interruptEventType
271273}
272274
273- func concatToolsNodeOutput (ctx context.Context , output * schema.StreamReader [callbacks.CallbackOutput ]) ([]* schema.Message , error ) {
275+ func ( r * replyChunkCallback ) concatToolsNodeOutput (ctx context.Context , output * schema.StreamReader [callbacks.CallbackOutput ]) ([]* schema.Message , error ) {
274276 defer output .Close ()
275- toolsMsgChunks := make ([][]* schema.Message , 0 , 5 )
277+ var toolsMsgChunks [][]* schema.Message
278+ var sr * schema.StreamReader [* schema.Message ]
279+ var sw * schema.StreamWriter [* schema.Message ]
280+ defer func () {
281+ if sw != nil {
282+ sw .Close ()
283+ }
284+ }()
285+ var streamInitialized bool
286+ returnDirectToolsMap := make (map [int ]bool )
287+ isReturnDirectToolsFirstCheck := true
288+ isToolsMsgChunksInit := false
289+
276290 for {
277291 cbOut , err := output .Recv ()
278292 if errors .Is (err , io .EOF ) {
279293 break
280294 }
281295
282296 if err != nil {
297+ if sw != nil {
298+ sw .Send (nil , err )
299+ }
283300 return nil , err
284301 }
285302
286303 msgs := convToolsNodeCallbackOutput (cbOut )
287304
288- for _ , msg := range msgs {
305+ if ! isToolsMsgChunksInit {
306+ isToolsMsgChunksInit = true
307+ toolsMsgChunks = make ([][]* schema.Message , len (msgs ))
308+ }
309+
310+ for mIndex , msg := range msgs {
311+
289312 if msg == nil {
290313 continue
291314 }
315+ if len (r .returnDirectlyTools ) > 0 {
316+ if isReturnDirectToolsFirstCheck {
317+ isReturnDirectToolsFirstCheck = false
318+ if _ , ok := r .returnDirectlyTools [msg .ToolName ]; ok {
319+ returnDirectToolsMap [mIndex ] = true
320+ }
321+ }
292322
293- findSameMsg := false
294- for i , msgChunks := range toolsMsgChunks {
295- if msg .ToolCallID == msgChunks [0 ].ToolCallID {
296- toolsMsgChunks [i ] = append (toolsMsgChunks [i ], msg )
297- findSameMsg = true
298- break
323+ if _ , ok := returnDirectToolsMap [mIndex ]; ok {
324+ if ! streamInitialized {
325+ sr , sw = schema.Pipe [* schema.Message ](5 )
326+ r .sw .Send (& entity.AgentEvent {
327+ EventType : singleagent .EventTypeOfToolsAsChatModelStream ,
328+ ChatModelAnswer : sr ,
329+ }, nil )
330+ streamInitialized = true
331+ }
332+ sw .Send (msg , nil )
299333 }
300334 }
301-
302- if ! findSameMsg {
303- toolsMsgChunks = append (toolsMsgChunks , []* schema.Message {msg })
335+ if toolsMsgChunks [mIndex ] == nil {
336+ toolsMsgChunks [mIndex ] = []* schema.Message {msg }
337+ } else {
338+ toolsMsgChunks [mIndex ] = append (toolsMsgChunks [mIndex ], msg )
304339 }
305340 }
306341 }
0 commit comments