Skip to content

Commit f649733

Browse files
committed
chore: improve ai agent
Signed-off-by: Zzde <zhangxh1997@gmail.com>
1 parent 3d868b7 commit f649733

File tree

8 files changed

+411
-123
lines changed

8 files changed

+411
-123
lines changed

main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ func setupAPIRouter(r *gin.RouterGroup, cm *cluster.ClusterManager) {
213213
// AI chat routes
214214
api.GET("/ai/status", ai.HandleAIStatus)
215215
api.POST("/ai/chat", ai.HandleChat)
216-
api.POST("/ai/execute", ai.HandleExecute)
216+
api.POST("/ai/execute/continue", ai.HandleExecuteContinue)
217217

218218
api.Use(middleware.RBACMiddleware())
219219
resources.RegisterRoutes(api)

pkg/ai/agent.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,9 @@ func buildContextualSystemPrompt(pageCtx *PageContext, runtimeCtx runtimePromptC
280280
}
281281

282282
if language == "zh" {
283-
prompt += "\n\nResponse language:\n- Always respond in Simplified Chinese unless the user explicitly asks for another language."
283+
prompt += "\n\nResponse language:\n- Prefer replying in the same language as the user's latest message.\n- If the user's latest message language is unclear, respond in Simplified Chinese unless the user explicitly asks for another language."
284284
} else {
285-
prompt += "\n\nResponse language:\n- Always respond in English unless the user explicitly asks for another language."
285+
prompt += "\n\nResponse language:\n- Prefer replying in the same language as the user's latest message.\n- If the user's latest message language is unclear, respond in English unless the user explicitly asks for another language."
286286
}
287287

288288
klog.V(4).Infof("system prompt %s", prompt)
@@ -299,6 +299,20 @@ func (a *Agent) ProcessChat(c *gin.Context, req *ChatRequest, sendEvent func(SSE
299299
}
300300
}
301301

302+
func (a *Agent) ContinuePendingAction(c *gin.Context, sessionID string, sendEvent func(SSEEvent)) error {
303+
session, err := agentPendingSessions.take(sessionID)
304+
if err != nil {
305+
return err
306+
}
307+
308+
switch session.Provider {
309+
case model.GeneralAIProviderAnthropic:
310+
return a.continueChatAnthropic(c, session, sendEvent)
311+
default:
312+
return a.continueChatOpenAI(c, session, sendEvent)
313+
}
314+
}
315+
302316
func parseToolCallArguments(raw string) (map[string]interface{}, error) {
303317
raw = strings.TrimSpace(raw)
304318
if raw == "" {

pkg/ai/anthropic.go

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai
22

33
import (
4+
"context"
45
"fmt"
56
"strings"
67

@@ -34,6 +35,45 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
3435
}
3536
sysPrompt := buildContextualSystemPrompt(req.PageContext, runtimeCtx, language)
3637
messages := toAnthropicMessages(req.Messages)
38+
a.runAnthropicConversation(ctx, c, sysPrompt, messages, sendEvent)
39+
}
40+
41+
func (a *Agent) continueChatAnthropic(c *gin.Context, session pendingSession, sendEvent func(SSEEvent)) error {
42+
ctx := c.Request.Context()
43+
result, isError := ExecuteTool(ctx, c, a.cs, session.ToolCall.Name, session.ToolCall.Args)
44+
45+
sendEvent(SSEEvent{
46+
Event: "tool_result",
47+
Data: map[string]interface{}{
48+
"tool": session.ToolCall.Name,
49+
"result": result,
50+
"is_error": isError,
51+
},
52+
})
53+
54+
toolResult := result
55+
if isError {
56+
toolResult = "Tool error: " + result
57+
}
58+
59+
messages := append([]anthropic.MessageParam(nil), session.AnthropicMessages...)
60+
messages = append(
61+
messages,
62+
anthropic.NewUserMessage(
63+
anthropic.NewToolResultBlock(session.ToolCall.ID, toolResult, isError),
64+
),
65+
)
66+
a.runAnthropicConversation(ctx, c, session.SystemPrompt, messages, sendEvent)
67+
return nil
68+
}
69+
70+
func (a *Agent) runAnthropicConversation(
71+
ctx context.Context,
72+
c *gin.Context,
73+
sysPrompt string,
74+
messages []anthropic.MessageParam,
75+
sendEvent func(SSEEvent),
76+
) {
3777
tools := AnthropicToolDefs()
3878

3979
maxIterations := 100
@@ -100,11 +140,25 @@ func (a *Agent) processChatAnthropic(c *gin.Context, req *ChatRequest, sendEvent
100140
toolResults = append(toolResults, anthropic.NewToolResultBlock(tc.ID, "Tool error: "+result, true))
101141
continue
102142
}
143+
if len(toolResults) > 0 {
144+
messages = append(messages, anthropic.NewUserMessage(toolResults...))
145+
}
146+
sessionID := agentPendingSessions.save(pendingSession{
147+
Provider: a.provider,
148+
SystemPrompt: sysPrompt,
149+
AnthropicMessages: append([]anthropic.MessageParam(nil), messages...),
150+
ToolCall: pendingToolCall{
151+
ID: tc.ID,
152+
Name: toolName,
153+
Args: args,
154+
},
155+
})
103156
sendEvent(SSEEvent{
104157
Event: "action_required",
105158
Data: map[string]interface{}{
106-
"tool": toolName,
107-
"args": args,
159+
"tool": toolName,
160+
"args": args,
161+
"session_id": sessionID,
108162
},
109163
})
110164
return

pkg/ai/handler.go

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,12 @@ func HandleChat(c *gin.Context) {
7676
sendEvent(SSEEvent{Event: "done", Data: map[string]string{}})
7777
}
7878

79-
// ExecuteRequest is the request body for the stateless execute endpoint.
80-
type ExecuteRequest struct {
81-
Tool string `json:"tool"`
82-
Args map[string]interface{} `json:"args"`
79+
// ContinueRequest is the JSON request body for resuming a pending AI action.
type ContinueRequest struct {
	// SessionID identifies the pending session to resume. It is read from the
	// "sessionId" JSON field; note that the "action_required" event publishes
	// the same value under the key "session_id".
	SessionID string `json:"sessionId"`
}
8482

85-
// HandleExecute executes a confirmed mutation action. Stateless — the client
86-
// sends the full tool name and args, no server-side session needed.
87-
func HandleExecute(c *gin.Context) {
83+
// HandleExecuteContinue resumes a pending AI action after user confirmation.
84+
func HandleExecuteContinue(c *gin.Context) {
8885
cfg, err := LoadRuntimeConfig()
8986
if err != nil {
9087
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to load AI config: %v", err)})
@@ -95,14 +92,13 @@ func HandleExecute(c *gin.Context) {
9592
return
9693
}
9794

98-
var req ExecuteRequest
95+
var req ContinueRequest
9996
if err := c.ShouldBindJSON(&req); err != nil {
10097
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err)})
10198
return
10299
}
103-
104-
if !MutationTools[req.Tool] {
105-
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Tool %s is not a mutation tool", req.Tool)})
100+
if strings.TrimSpace(req.SessionID) == "" {
101+
c.JSON(http.StatusBadRequest, gin.H{"error": "sessionId is required"})
106102
return
107103
}
108104

@@ -112,25 +108,28 @@ func HandleExecute(c *gin.Context) {
112108
return
113109
}
114110

115-
result, isError := ExecuteTool(c.Request.Context(), c, clientSet, req.Tool, req.Args)
116-
if isError {
117-
statusCode := http.StatusInternalServerError
118-
if strings.HasPrefix(result, "Forbidden: ") {
119-
statusCode = http.StatusForbidden
120-
} else if strings.HasPrefix(result, "Error: ") || strings.HasPrefix(result, "Unknown tool: ") {
121-
statusCode = http.StatusBadRequest
122-
}
123-
c.JSON(statusCode, gin.H{
124-
"status": "error",
125-
"message": result,
126-
})
111+
agent, err := NewAgent(clientSet, cfg)
112+
if err != nil {
113+
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to create AI agent: %v", err)})
127114
return
128115
}
129116

130-
c.JSON(http.StatusOK, gin.H{
131-
"status": "ok",
132-
"message": result,
133-
})
117+
c.Header("Content-Type", "text/event-stream")
118+
c.Header("Cache-Control", "no-cache")
119+
c.Header("Connection", "keep-alive")
120+
c.Header("X-Accel-Buffering", "no")
121+
122+
sendEvent := func(event SSEEvent) {
123+
data := MarshalSSEEvent(event)
124+
_, _ = fmt.Fprint(c.Writer, data)
125+
c.Writer.Flush()
126+
}
127+
128+
if err := agent.ContinuePendingAction(c, req.SessionID, sendEvent); err != nil {
129+
sendEvent(SSEEvent{Event: "error", Data: map[string]string{"message": err.Error()}})
130+
}
131+
132+
sendEvent(SSEEvent{Event: "done", Data: map[string]string{}})
134133
}
135134

136135
func HandleGetGeneralSetting(c *gin.Context) {

pkg/ai/openai.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"sort"
@@ -37,7 +38,38 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu
3738
}
3839
sysPrompt := buildContextualSystemPrompt(req.PageContext, runtimeCtx, language)
3940
messages := toOpenAIMessages(sysPrompt, req.Messages)
41+
a.runOpenAIConversation(ctx, c, messages, sendEvent)
42+
}
43+
44+
func (a *Agent) continueChatOpenAI(c *gin.Context, session pendingSession, sendEvent func(SSEEvent)) error {
45+
ctx := c.Request.Context()
46+
result, isError := ExecuteTool(ctx, c, a.cs, session.ToolCall.Name, session.ToolCall.Args)
47+
48+
sendEvent(SSEEvent{
49+
Event: "tool_result",
50+
Data: map[string]interface{}{
51+
"tool": session.ToolCall.Name,
52+
"result": result,
53+
"is_error": isError,
54+
},
55+
})
56+
57+
if isError {
58+
result = "Tool error: " + result
59+
}
4060

61+
messages := append([]openai.ChatCompletionMessageParamUnion(nil), session.OpenAIMessages...)
62+
messages = append(messages, openai.ToolMessage(result, session.ToolCall.ID))
63+
a.runOpenAIConversation(ctx, c, messages, sendEvent)
64+
return nil
65+
}
66+
67+
func (a *Agent) runOpenAIConversation(
68+
ctx context.Context,
69+
c *gin.Context,
70+
messages []openai.ChatCompletionMessageParamUnion,
71+
sendEvent func(SSEEvent),
72+
) {
4173
tools := OpenAIToolDefs()
4274

4375
maxIterations := 100
@@ -107,11 +139,21 @@ func (a *Agent) processChatOpenAI(c *gin.Context, req *ChatRequest, sendEvent fu
107139
messages = append(messages, openai.ToolMessage("Tool error: "+result, tc.ID))
108140
continue
109141
}
142+
sessionID := agentPendingSessions.save(pendingSession{
143+
Provider: a.provider,
144+
OpenAIMessages: append([]openai.ChatCompletionMessageParamUnion(nil), messages...),
145+
ToolCall: pendingToolCall{
146+
ID: tc.ID,
147+
Name: toolName,
148+
Args: args,
149+
},
150+
})
110151
sendEvent(SSEEvent{
111152
Event: "action_required",
112153
Data: map[string]interface{}{
113-
"tool": toolName,
114-
"args": args,
154+
"tool": toolName,
155+
"args": args,
156+
"session_id": sessionID,
115157
},
116158
})
117159
return

pkg/ai/pending_session.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package ai
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/hex"
6+
"fmt"
7+
"sync"
8+
"time"
9+
10+
anthropic "github.com/anthropics/anthropic-sdk-go"
11+
"github.com/openai/openai-go"
12+
)
13+
14+
// pendingSessionTTL is how long a saved pending session stays redeemable.
// save stamps each session with now+TTL, and sessions past that instant are
// dropped the next time the store is saved to or taken from.
const pendingSessionTTL = 15 * time.Minute
15+
16+
// pendingToolCall records the single tool invocation that is waiting for user
// confirmation before it is executed.
type pendingToolCall struct {
	// ID is the provider-issued tool call ID; the tool result is sent back
	// to the model under this ID.
	ID string
	// Name is the tool to execute once the action is confirmed.
	Name string
	// Args are the parsed arguments passed to the tool on execution.
	Args map[string]interface{}
}
21+
22+
type pendingSession struct {
23+
Provider string
24+
SystemPrompt string
25+
OpenAIMessages []openai.ChatCompletionMessageParamUnion
26+
AnthropicMessages []anthropic.MessageParam
27+
ToolCall pendingToolCall
28+
ExpiresAt time.Time
29+
}
30+
31+
type pendingSessionStore struct {
32+
mu sync.Mutex
33+
sessions map[string]pendingSession
34+
}
35+
36+
var agentPendingSessions = &pendingSessionStore{
37+
sessions: make(map[string]pendingSession),
38+
}
39+
40+
func (s *pendingSessionStore) save(session pendingSession) string {
41+
s.mu.Lock()
42+
defer s.mu.Unlock()
43+
44+
now := time.Now()
45+
s.cleanupExpiredLocked(now)
46+
sessionID := newPendingSessionID()
47+
session.ExpiresAt = now.Add(pendingSessionTTL)
48+
s.sessions[sessionID] = session
49+
return sessionID
50+
}
51+
52+
func (s *pendingSessionStore) take(sessionID string) (pendingSession, error) {
53+
s.mu.Lock()
54+
defer s.mu.Unlock()
55+
56+
s.cleanupExpiredLocked(time.Now())
57+
session, ok := s.sessions[sessionID]
58+
if !ok {
59+
return pendingSession{}, fmt.Errorf("pending action not found or expired")
60+
}
61+
delete(s.sessions, sessionID)
62+
return session, nil
63+
}
64+
65+
func (s *pendingSessionStore) cleanupExpiredLocked(now time.Time) {
66+
for id, session := range s.sessions {
67+
if now.After(session.ExpiresAt) {
68+
delete(s.sessions, id)
69+
}
70+
}
71+
}
72+
73+
// newPendingSessionID returns a fresh identifier for a pending session.
//
// It is normally 16 bytes from crypto/rand encoded as 32 lowercase hex
// characters. If the random source reports an error, it falls back to
// "pending-<unix nanoseconds>".
// NOTE(review): the fallback ID is derived from the clock and is guessable;
// confirm that is acceptable for an ID that authorises a tool execution.
func newPendingSessionID() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return fmt.Sprintf("pending-%d", time.Now().UnixNano())
	}
	return hex.EncodeToString(raw[:])
}

0 commit comments

Comments
 (0)