Skip to content

Commit 0367e66

Browse files
authored
fix(singleagent): workflow as tool return directly (#526)
1 parent 38b63f0 commit 0367e66

File tree

5 files changed

+72
-34
lines changed

5 files changed

+72
-34
lines changed

backend/api/model/crossdomain/singleagent/single_agent.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ type AgentRuntime struct {
3737
type EventType string
3838

3939
const (
40-
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
41-
EventTypeOfToolsMessage EventType = "tools_message"
42-
EventTypeOfFuncCall EventType = "func_call"
43-
EventTypeOfSuggest EventType = "suggest"
44-
EventTypeOfKnowledge EventType = "knowledge"
45-
EventTypeOfInterrupt EventType = "interrupt"
40+
EventTypeOfChatModelAnswer EventType = "chatmodel_answer"
41+
EventTypeOfToolsAsChatModelStream EventType = "tools_as_chatmodel_answer"
42+
EventTypeOfToolsMessage EventType = "tools_message"
43+
EventTypeOfFuncCall EventType = "func_call"
44+
EventTypeOfSuggest EventType = "suggest"
45+
EventTypeOfKnowledge EventType = "knowledge"
46+
EventTypeOfInterrupt EventType = "interrupt"
4647
)
4748

4849
type AgentEvent struct {

backend/domain/agent/singleagent/internal/agentflow/agent_flow_builder.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
113113
}
114114
tr := newPreToolRetriever(&toolPreCallConf{})
115115

116-
wfTools, toolsReturnDirectly, err := newWorkflowTools(ctx, &workflowConfig{
116+
wfTools, returnDirectlyTools, err := newWorkflowTools(ctx, &workflowConfig{
117117
wfInfos: conf.Agent.Workflow,
118118
})
119119
if err != nil {
@@ -176,7 +176,7 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
176176
ToolsConfig: compose.ToolsNodeConfig{
177177
Tools: agentTools,
178178
},
179-
ToolReturnDirectly: toolsReturnDirectly,
179+
ToolReturnDirectly: returnDirectlyTools,
180180
ModelNodeName: keyOfReActAgentChatModel,
181181
ToolsNodeName: keyOfReActAgentToolsNode,
182182
})
@@ -273,10 +273,11 @@ func BuildAgent(ctx context.Context, conf *Config) (r *AgentRunner, err error) {
273273
}
274274

275275
return &AgentRunner{
276-
runner: runner,
277-
requireCheckpoint: requireCheckpoint,
278-
modelInfo: modelInfo,
279-
containWfTool: containWfTool,
276+
runner: runner,
277+
requireCheckpoint: requireCheckpoint,
278+
modelInfo: modelInfo,
279+
containWfTool: containWfTool,
280+
returnDirectlyTools: returnDirectlyTools,
280281
}, nil
281282
}
282283

backend/domain/agent/singleagent/internal/agentflow/agent_flow_runner.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,17 @@ type AgentRunner struct {
5757
runner compose.Runnable[*AgentRequest, *schema.Message]
5858
requireCheckpoint bool
5959

60-
containWfTool bool
61-
modelInfo *modelmgr.Model
60+
returnDirectlyTools map[string]struct{}
61+
containWfTool bool
62+
modelInfo *modelmgr.Model
6263
}
6364

6465
func (r *AgentRunner) StreamExecute(ctx context.Context, req *AgentRequest) (
6566
sr *schema.StreamReader[*entity.AgentEvent], err error,
6667
) {
6768
executeID := uuid.New()
6869

69-
hdl, sr, sw := newReplyCallback(ctx, executeID.String())
70+
hdl, sr, sw := newReplyCallback(ctx, executeID.String(), r.returnDirectlyTools)
7071

7172
go func() {
7273
defer func() {

backend/domain/agent/singleagent/internal/agentflow/callback_reply_chunk.go

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,15 @@ import (
3838
"github.com/coze-dev/coze-studio/backend/pkg/logs"
3939
)
4040

41-
func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handler,
41+
func newReplyCallback(_ context.Context, executeID string, returnDirectlyTools map[string]struct{}) (clb callbacks.Handler,
4242
sr *schema.StreamReader[*entity.AgentEvent], sw *schema.StreamWriter[*entity.AgentEvent],
4343
) {
4444
sr, sw = schema.Pipe[*entity.AgentEvent](10)
4545

4646
rcc := &replyChunkCallback{
47-
sw: sw,
48-
executeID: executeID,
47+
sw: sw,
48+
executeID: executeID,
49+
returnDirectlyTools: returnDirectlyTools,
4950
}
5051

5152
clb = callbacks.NewHandlerBuilder().
@@ -59,8 +60,9 @@ func newReplyCallback(_ context.Context, executeID string) (clb callbacks.Handle
5960
}
6061

6162
type replyChunkCallback struct {
62-
sw *schema.StreamWriter[*entity.AgentEvent]
63-
executeID string
63+
sw *schema.StreamWriter[*entity.AgentEvent]
64+
executeID string
65+
returnDirectlyTools map[string]struct{}
6466
}
6567

6668
func (r *replyChunkCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
@@ -201,7 +203,7 @@ func (r *replyChunkCallback) OnEndWithStreamOutput(ctx context.Context, info *ca
201203
}, nil)
202204
return ctx
203205
case compose.ComponentOfToolsNode:
204-
toolsMessage, err := concatToolsNodeOutput(ctx, output)
206+
toolsMessage, err := r.concatToolsNodeOutput(ctx, output)
205207
if err != nil {
206208
r.sw.Send(nil, err)
207209
return ctx
@@ -270,37 +272,70 @@ func convInterruptEventType(interruptEvent any) singleagent.InterruptEventType {
270272
return interruptEventType
271273
}
272274

273-
func concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
275+
func (r *replyChunkCallback) concatToolsNodeOutput(ctx context.Context, output *schema.StreamReader[callbacks.CallbackOutput]) ([]*schema.Message, error) {
274276
defer output.Close()
275-
toolsMsgChunks := make([][]*schema.Message, 0, 5)
277+
var toolsMsgChunks [][]*schema.Message
278+
var sr *schema.StreamReader[*schema.Message]
279+
var sw *schema.StreamWriter[*schema.Message]
280+
defer func() {
281+
if sw != nil {
282+
sw.Close()
283+
}
284+
}()
285+
var streamInitialized bool
286+
returnDirectToolsMap := make(map[int]bool)
287+
isReturnDirectToolsFirstCheck := true
288+
isToolsMsgChunksInit := false
289+
276290
for {
277291
cbOut, err := output.Recv()
278292
if errors.Is(err, io.EOF) {
279293
break
280294
}
281295

282296
if err != nil {
297+
if sw != nil {
298+
sw.Send(nil, err)
299+
}
283300
return nil, err
284301
}
285302

286303
msgs := convToolsNodeCallbackOutput(cbOut)
287304

288-
for _, msg := range msgs {
305+
if !isToolsMsgChunksInit {
306+
isToolsMsgChunksInit = true
307+
toolsMsgChunks = make([][]*schema.Message, len(msgs))
308+
}
309+
310+
for mIndex, msg := range msgs {
311+
289312
if msg == nil {
290313
continue
291314
}
315+
if len(r.returnDirectlyTools) > 0 {
316+
if isReturnDirectToolsFirstCheck {
317+
isReturnDirectToolsFirstCheck = false
318+
if _, ok := r.returnDirectlyTools[msg.ToolName]; ok {
319+
returnDirectToolsMap[mIndex] = true
320+
}
321+
}
292322

293-
findSameMsg := false
294-
for i, msgChunks := range toolsMsgChunks {
295-
if msg.ToolCallID == msgChunks[0].ToolCallID {
296-
toolsMsgChunks[i] = append(toolsMsgChunks[i], msg)
297-
findSameMsg = true
298-
break
323+
if _, ok := returnDirectToolsMap[mIndex]; ok {
324+
if !streamInitialized {
325+
sr, sw = schema.Pipe[*schema.Message](5)
326+
r.sw.Send(&entity.AgentEvent{
327+
EventType: singleagent.EventTypeOfToolsAsChatModelStream,
328+
ChatModelAnswer: sr,
329+
}, nil)
330+
streamInitialized = true
331+
}
332+
sw.Send(msg, nil)
299333
}
300334
}
301-
302-
if !findSameMsg {
303-
toolsMsgChunks = append(toolsMsgChunks, []*schema.Message{msg})
335+
if toolsMsgChunks[mIndex] == nil {
336+
toolsMsgChunks[mIndex] = []*schema.Message{msg}
337+
} else {
338+
toolsMsgChunks[mIndex] = append(toolsMsgChunks[mIndex], msg)
304339
}
305340
}
306341
}

backend/domain/conversation/agentrun/service/agent_run_impl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func transformEventMap(eventType singleagent.EventType) (message.MessageType, er
200200
return message.MessageTypeKnowledge, nil
201201
case singleagent.EventTypeOfToolsMessage:
202202
return message.MessageTypeToolResponse, nil
203-
case singleagent.EventTypeOfChatModelAnswer:
203+
case singleagent.EventTypeOfChatModelAnswer, singleagent.EventTypeOfToolsAsChatModelStream:
204204
return message.MessageTypeAnswer, nil
205205
case singleagent.EventTypeOfSuggest:
206206
return message.MessageTypeFlowUp, nil

0 commit comments

Comments
 (0)