Skip to content

Commit f7cf7c2

Browse files
committed
chore: improve ai agent
Signed-off-by: Zzde <zhangxh1997@gmail.com>
1 parent 030ec6d commit f7cf7c2

File tree

13 files changed

+333
-66
lines changed

13 files changed

+333
-66
lines changed

pkg/ai/anthropic.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ func (a *Agent) runAnthropicConversation(
7373
messages []anthropic.MessageParam,
7474
sendEvent func(SSEEvent),
7575
) {
76-
tools := AnthropicToolDefs()
76+
tools := AnthropicToolDefs(a.cs)
7777

7878
maxIterations := 100
7979
for i := 0; i < maxIterations; i++ {

pkg/ai/openai.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ func (a *Agent) runOpenAIConversation(
6969
messages []openai.ChatCompletionMessageParamUnion,
7070
sendEvent func(SSEEvent),
7171
) {
72-
tools := OpenAIToolDefs()
72+
tools := OpenAIToolDefs(a.cs)
7373

7474
maxIterations := 100
7575
for i := 0; i < maxIterations; i++ {

pkg/ai/pending_session.go

Lines changed: 69 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,12 @@ import (
44
"crypto/rand"
55
"encoding/hex"
66
"fmt"
7-
"sync"
87
"time"
98

109
anthropic "github.com/anthropics/anthropic-sdk-go"
1110
"github.com/openai/openai-go"
11+
"github.com/zxh326/kite/pkg/model"
12+
"k8s.io/klog/v2"
1213
)
1314

1415
const pendingSessionTTL = 15 * time.Minute
@@ -28,46 +29,85 @@ type pendingSession struct {
2829
ExpiresAt time.Time
2930
}
3031

31-
type pendingSessionStore struct {
32-
mu sync.Mutex
33-
sessions map[string]pendingSession
34-
}
32+
type pendingSessionStore struct{}
3533

36-
var agentPendingSessions = &pendingSessionStore{
37-
sessions: make(map[string]pendingSession),
38-
}
34+
var agentPendingSessions = &pendingSessionStore{}
3935

4036
func (s *pendingSessionStore) save(session pendingSession) string {
41-
s.mu.Lock()
42-
defer s.mu.Unlock()
43-
44-
now := time.Now()
45-
s.cleanupExpiredLocked(now)
4637
sessionID := newPendingSessionID()
47-
session.ExpiresAt = now.Add(pendingSessionTTL)
48-
s.sessions[sessionID] = session
38+
session.ExpiresAt = time.Now().Add(pendingSessionTTL)
39+
40+
dbSession := &model.PendingSession{
41+
SessionID: sessionID,
42+
Provider: session.Provider,
43+
SystemPrompt: session.SystemPrompt,
44+
ToolCallID: session.ToolCall.ID,
45+
ToolCallName: session.ToolCall.Name,
46+
ExpiresAt: session.ExpiresAt,
47+
}
48+
49+
// Marshal messages and args
50+
if err := dbSession.OpenAIMessages.Marshal(session.OpenAIMessages); err != nil {
51+
klog.Errorf("Failed to marshal OpenAI messages: %v", err)
52+
return ""
53+
}
54+
if err := dbSession.AnthropicMessages.Marshal(session.AnthropicMessages); err != nil {
55+
klog.Errorf("Failed to marshal Anthropic messages: %v", err)
56+
return ""
57+
}
58+
if err := dbSession.ToolCallArgs.Marshal(session.ToolCall.Args); err != nil {
59+
klog.Errorf("Failed to marshal tool call args: %v", err)
60+
return ""
61+
}
62+
63+
if err := model.SavePendingSession(dbSession); err != nil {
64+
klog.Errorf("Failed to save pending session: %v", err)
65+
return ""
66+
}
67+
68+
// Cleanup expired sessions asynchronously
69+
go func() {
70+
if err := model.CleanupExpiredPendingSessions(); err != nil {
71+
klog.V(4).Infof("Failed to cleanup expired pending sessions: %v", err)
72+
}
73+
}()
74+
4975
return sessionID
5076
}
5177

5278
func (s *pendingSessionStore) take(sessionID string) (pendingSession, error) {
53-
s.mu.Lock()
54-
defer s.mu.Unlock()
55-
56-
s.cleanupExpiredLocked(time.Now())
57-
session, ok := s.sessions[sessionID]
58-
if !ok {
79+
dbSession, err := model.GetPendingSession(sessionID)
80+
if err != nil {
5981
return pendingSession{}, fmt.Errorf("pending action not found or expired")
6082
}
61-
delete(s.sessions, sessionID)
62-
return session, nil
63-
}
6483

65-
func (s *pendingSessionStore) cleanupExpiredLocked(now time.Time) {
66-
for id, session := range s.sessions {
67-
if now.After(session.ExpiresAt) {
68-
delete(s.sessions, id)
69-
}
84+
// Delete the session immediately after retrieving it
85+
if err := model.DeletePendingSession(sessionID); err != nil {
86+
klog.Warningf("Failed to delete pending session %s: %v", sessionID, err)
87+
}
88+
89+
session := pendingSession{
90+
Provider: dbSession.Provider,
91+
SystemPrompt: dbSession.SystemPrompt,
92+
ExpiresAt: dbSession.ExpiresAt,
93+
ToolCall: pendingToolCall{
94+
ID: dbSession.ToolCallID,
95+
Name: dbSession.ToolCallName,
96+
},
7097
}
98+
99+
// Unmarshal messages and args
100+
if session.OpenAIMessages, err = dbSession.UnmarshalOpenAIMessages(); err != nil {
101+
return pendingSession{}, fmt.Errorf("failed to unmarshal OpenAI messages: %w", err)
102+
}
103+
if session.AnthropicMessages, err = dbSession.UnmarshalAnthropicMessages(); err != nil {
104+
return pendingSession{}, fmt.Errorf("failed to unmarshal Anthropic messages: %w", err)
105+
}
106+
if session.ToolCall.Args, err = dbSession.UnmarshalToolCallArgs(); err != nil {
107+
return pendingSession{}, fmt.Errorf("failed to unmarshal tool call args: %w", err)
108+
}
109+
110+
return session, nil
71111
}
72112

73113
func newPendingSessionID() string {

pkg/ai/tools.go

Lines changed: 112 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -39,8 +39,8 @@ type agentToolDefinition struct {
3939
Required []string
4040
}
4141

42-
func toolDefinitions() []agentToolDefinition {
43-
return []agentToolDefinition{
42+
func toolDefinitions(cs *cluster.ClientSet) []agentToolDefinition {
43+
tools := []agentToolDefinition{
4444
{
4545
Name: "get_resource",
4646
Description: "Get a specific Kubernetes resource by kind, name, and optionally namespace. Returns the resource details in YAML format.",
@@ -175,7 +175,11 @@ func toolDefinitions() []agentToolDefinition {
175175
},
176176
Required: []string{"kind", "name"},
177177
},
178-
{
178+
}
179+
180+
// Only add Prometheus tool if Prometheus client is available
181+
if cs != nil && cs.PromClient != nil {
182+
tools = append(tools, agentToolDefinition{
179183
Name: "query_prometheus",
180184
Description: "Execute a PromQL query against Prometheus to retrieve metrics data. Use this to get monitoring information like CPU usage, memory usage, network traffic, custom application metrics, etc. Returns time series data or instant values. Note: Requires cluster-wide read access as metrics can span multiple namespaces.",
181185
Properties: map[string]any{
@@ -194,12 +198,14 @@ func toolDefinitions() []agentToolDefinition {
194198
},
195199
},
196200
Required: []string{"query"},
197-
},
201+
})
198202
}
203+
204+
return tools
199205
}
200206

201-
func OpenAIToolDefs() []openai.ChatCompletionToolParam {
202-
defs := toolDefinitions()
207+
func OpenAIToolDefs(cs *cluster.ClientSet) []openai.ChatCompletionToolParam {
208+
defs := toolDefinitions(cs)
203209
tools := make([]openai.ChatCompletionToolParam, 0, len(defs))
204210

205211
for _, def := range defs {
@@ -223,8 +229,8 @@ func OpenAIToolDefs() []openai.ChatCompletionToolParam {
223229
return tools
224230
}
225231

226-
func AnthropicToolDefs() []anthropic.ToolUnionParam {
227-
defs := toolDefinitions()
232+
func AnthropicToolDefs(cs *cluster.ClientSet) []anthropic.ToolUnionParam {
233+
defs := toolDefinitions(cs)
228234
tools := make([]anthropic.ToolUnionParam, 0, len(defs))
229235

230236
for _, def := range defs {
@@ -660,6 +666,8 @@ func ExecuteTool(ctx context.Context, c *gin.Context, cs *cluster.ClientSet, too
660666
return result, true
661667
}
662668

669+
user, _ := currentUserFromGin(c)
670+
663671
switch toolName {
664672
case "get_resource":
665673
return executeGetResource(ctx, cs, args)
@@ -670,20 +678,56 @@ func ExecuteTool(ctx context.Context, c *gin.Context, cs *cluster.ClientSet, too
670678
case "get_cluster_overview":
671679
return executeGetClusterOverview(ctx, cs)
672680
case "create_resource":
673-
return executeCreateResource(ctx, cs, args)
681+
return executeCreateResource(ctx, cs, user, args)
674682
case "update_resource":
675-
return executeUpdateResource(ctx, cs, args)
683+
return executeUpdateResource(ctx, cs, user, args)
676684
case "patch_resource":
677-
return executePatchResource(ctx, cs, args)
685+
return executePatchResource(ctx, cs, user, args)
678686
case "delete_resource":
679-
return executeDeleteResource(ctx, cs, args)
687+
return executeDeleteResource(ctx, cs, user, args)
680688
case "query_prometheus":
681689
return executeQueryPrometheus(ctx, cs, args)
682690
default:
683691
return fmt.Sprintf("Unknown tool: %s", toolName), true
684692
}
685693
}
686694

695+
func recordResourceHistory(cs *cluster.ClientSet, user pkgmodel.User, kind, name, namespace, opType, resourceYAML, previousYAML string, success bool, err error) {
696+
errMsg := ""
697+
if err != nil {
698+
errMsg = err.Error()
699+
}
700+
701+
history := pkgmodel.ResourceHistory{
702+
ClusterName: cs.Name,
703+
ResourceType: kind,
704+
ResourceName: name,
705+
Namespace: namespace,
706+
OperationType: opType,
707+
OperationSource: "AI",
708+
ResourceYAML: resourceYAML,
709+
PreviousYAML: previousYAML,
710+
Success: success,
711+
ErrorMessage: errMsg,
712+
OperatorID: user.ID,
713+
}
714+
if dbErr := pkgmodel.DB.Create(&history).Error; dbErr != nil {
715+
klog.Errorf("Failed to create resource history: %v", dbErr)
716+
}
717+
}
718+
719+
func objectToYAML(obj *unstructured.Unstructured) string {
720+
if obj == nil {
721+
return ""
722+
}
723+
obj.SetManagedFields(nil)
724+
yamlBytes, err := yaml.Marshal(obj)
725+
if err != nil {
726+
return ""
727+
}
728+
return string(yamlBytes)
729+
}
730+
687731
func executeGetResource(ctx context.Context, cs *cluster.ClientSet, args map[string]interface{}) (string, bool) {
688732
kind, err := getRequiredString(args, "kind")
689733
if err != nil {
@@ -1188,35 +1232,59 @@ func executeGetClusterOverview(ctx context.Context, cs *cluster.ClientSet) (stri
11881232
return sb.String(), false
11891233
}
11901234

1191-
func executeCreateResource(ctx context.Context, cs *cluster.ClientSet, args map[string]interface{}) (string, bool) {
1235+
func executeCreateResource(ctx context.Context, cs *cluster.ClientSet, user pkgmodel.User, args map[string]interface{}) (string, bool) {
11921236
obj, err := parseResourceYAML(args)
11931237
if err != nil {
11941238
return "Error: " + err.Error(), true
11951239
}
11961240

1197-
if err := cs.K8sClient.Create(ctx, obj); err != nil {
1241+
yamlStr, _ := getRequiredString(args, "yaml")
1242+
resource := resolveResourceInfoForObject(ctx, cs, obj)
1243+
err = cs.K8sClient.Create(ctx, obj)
1244+
1245+
recordResourceHistory(cs, user, resource.Resource, obj.GetName(), obj.GetNamespace(), "create", yamlStr, "", err == nil, err)
1246+
1247+
if err != nil {
11981248
return fmt.Sprintf("Error creating %s/%s: %v", obj.GetKind(), obj.GetName(), err), true
11991249
}
12001250

12011251
klog.Infof("AI Agent created resource: %s/%s in namespace %s", obj.GetKind(), obj.GetName(), obj.GetNamespace())
12021252
return fmt.Sprintf("Successfully created %s/%s", obj.GetKind(), obj.GetName()), false
12031253
}
12041254

1205-
func executeUpdateResource(ctx context.Context, cs *cluster.ClientSet, args map[string]interface{}) (string, bool) {
1255+
func executeUpdateResource(ctx context.Context, cs *cluster.ClientSet, user pkgmodel.User, args map[string]interface{}) (string, bool) {
12061256
obj, err := parseResourceYAML(args)
12071257
if err != nil {
12081258
return "Error: " + err.Error(), true
12091259
}
12101260

1211-
if err := cs.K8sClient.Update(ctx, obj); err != nil {
1261+
yamlStr, _ := getRequiredString(args, "yaml")
1262+
1263+
// Get previous state
1264+
resource := resolveResourceInfoForObject(ctx, cs, obj)
1265+
prevObj := buildObjectForResource(resource)
1266+
key := k8stypes.NamespacedName{
1267+
Name: obj.GetName(),
1268+
Namespace: normalizeNamespace(resource, obj.GetNamespace()),
1269+
}
1270+
var previousYAML string
1271+
if getErr := cs.K8sClient.Get(ctx, key, prevObj); getErr == nil {
1272+
previousYAML = objectToYAML(prevObj)
1273+
}
1274+
1275+
err = cs.K8sClient.Update(ctx, obj)
1276+
1277+
recordResourceHistory(cs, user, resource.Resource, obj.GetName(), obj.GetNamespace(), "update", yamlStr, previousYAML, err == nil, err)
1278+
1279+
if err != nil {
12121280
return fmt.Sprintf("Error updating %s/%s: %v", obj.GetKind(), obj.GetName(), err), true
12131281
}
12141282

12151283
klog.Infof("AI Agent updated resource: %s/%s in namespace %s", obj.GetKind(), obj.GetName(), obj.GetNamespace())
12161284
return fmt.Sprintf("Successfully updated %s/%s", obj.GetKind(), obj.GetName()), false
12171285
}
12181286

1219-
func executePatchResource(ctx context.Context, cs *cluster.ClientSet, args map[string]interface{}) (string, bool) {
1287+
func executePatchResource(ctx context.Context, cs *cluster.ClientSet, user pkgmodel.User, args map[string]interface{}) (string, bool) {
12201288
kind, err := getRequiredString(args, "kind")
12211289
if err != nil {
12221290
return "Error: " + err.Error(), true
@@ -1245,17 +1313,30 @@ func executePatchResource(ctx context.Context, cs *cluster.ClientSet, args map[s
12451313
return fmt.Sprintf("Error finding %s/%s: %v", resource.Kind, name, err), true
12461314
}
12471315

1316+
// Get previous state
1317+
previousYAML := objectToYAML(obj.DeepCopy())
1318+
12481319
patchBytes := []byte(patchStr)
12491320
patch := client.RawPatch(k8stypes.StrategicMergePatchType, patchBytes)
1250-
if err := cs.K8sClient.Patch(ctx, obj, patch); err != nil {
1321+
err = cs.K8sClient.Patch(ctx, obj, patch)
1322+
1323+
// Get current state after patch
1324+
currentYAML := ""
1325+
if err == nil {
1326+
currentYAML = objectToYAML(obj)
1327+
}
1328+
1329+
recordResourceHistory(cs, user, resource.Resource, name, normalizeNamespace(resource, namespace), "patch", currentYAML, previousYAML, err == nil, err)
1330+
1331+
if err != nil {
12511332
return fmt.Sprintf("Error patching %s/%s: %v", resource.Kind, name, err), true
12521333
}
12531334

12541335
klog.Infof("AI Agent patched resource: %s/%s in namespace %s", resource.Kind, name, normalizeNamespace(resource, namespace))
12551336
return fmt.Sprintf("Successfully patched %s/%s", resource.Kind, name), false
12561337
}
12571338

1258-
func executeDeleteResource(ctx context.Context, cs *cluster.ClientSet, args map[string]interface{}) (string, bool) {
1339+
func executeDeleteResource(ctx context.Context, cs *cluster.ClientSet, user pkgmodel.User, args map[string]interface{}) (string, bool) {
12591340
kind, err := getRequiredString(args, "kind")
12601341
if err != nil {
12611342
return "Error: " + err.Error(), true
@@ -1273,14 +1354,22 @@ func executeDeleteResource(ctx context.Context, cs *cluster.ClientSet, args map[
12731354
Name: name,
12741355
Namespace: normalizeNamespace(resource, namespace),
12751356
}
1276-
if err := cs.K8sClient.Get(ctx, key, obj); err != nil {
1277-
if apierrors.IsNotFound(err) {
1357+
1358+
// Get previous state before deletion
1359+
var previousYAML string
1360+
if getErr := cs.K8sClient.Get(ctx, key, obj); getErr != nil {
1361+
if apierrors.IsNotFound(getErr) {
12781362
return fmt.Sprintf("%s/%s not found, already deleted", resource.Kind, name), false
12791363
}
1280-
return fmt.Sprintf("Error finding %s/%s: %v", resource.Kind, name, err), true
1364+
return fmt.Sprintf("Error finding %s/%s: %v", resource.Kind, name, getErr), true
12811365
}
12821366

1283-
if err := cs.K8sClient.Delete(ctx, obj); err != nil {
1367+
previousYAML = objectToYAML(obj)
1368+
err = cs.K8sClient.Delete(ctx, obj)
1369+
1370+
recordResourceHistory(cs, user, resource.Resource, name, normalizeNamespace(resource, namespace), "delete", "", previousYAML, err == nil || apierrors.IsNotFound(err), err)
1371+
1372+
if err != nil {
12841373
if apierrors.IsNotFound(err) {
12851374
return fmt.Sprintf("%s/%s not found, already deleted", resource.Kind, name), false
12861375
}

0 commit comments

Comments
 (0)