
Commit a7daa31

perf: Resurrection Kael
1 parent 7c1a4fc commit a7daa31


8 files changed: +401 -109 lines changed


pkg/httpd/chat.go

+191 -101

@@ -1,163 +1,253 @@
 package httpd
 
 import (
+	"context"
 	"encoding/json"
+	"fmt"
 	"github.com/jumpserver/koko/pkg/common"
+	"github.com/jumpserver/koko/pkg/i18n"
+	"github.com/jumpserver/koko/pkg/logger"
+	"github.com/jumpserver/koko/pkg/proxy"
+	"github.com/jumpserver/koko/pkg/session"
 	"github.com/sashabaranov/go-openai"
 	"sync"
 	"time"

 	"github.com/jumpserver/koko/pkg/jms-sdk-go/model"
-	"github.com/jumpserver/koko/pkg/logger"
 	"github.com/jumpserver/koko/pkg/srvconn"
 )

 var _ Handler = (*chat)(nil)

 type chat struct {
-	ws *UserWebsocket
+	ws   *UserWebsocket
+	term *model.TerminalConfig

-	conversationMap sync.Map
-
-	termConf *model.TerminalConfig
+	// conversations: map[conversationID]*AIConversation
+	conversations sync.Map
 }

 func (h *chat) Name() string {
 	return ChatName
 }

-func (h *chat) CleanUp() {
-	h.CleanConversationMap()
-}
+func (h *chat) CleanUp() { h.cleanupAll() }

 func (h *chat) CheckValidation() error {
 	return nil
 }

 func (h *chat) HandleMessage(msg *Message) {
-	conversationID := msg.Id
-	conversation := &AIConversation{}
-
-	if conversationID == "" {
-		id := common.UUID()
-		conversation = &AIConversation{
-			Id:                   id,
-			Prompt:               msg.Prompt,
-			HistoryRecords:       make([]string, 0),
-			InterruptCurrentChat: false,
-		}
+	if msg.Interrupt {
+		h.interrupt(msg.Id)
+		return
+	}

-		// TODO: currently a websocket connection only retains one conversation
-		h.CleanConversationMap()
-		h.conversationMap.Store(id, conversation)
-	} else {
-		c, ok := h.conversationMap.Load(conversationID)
-		if !ok {
-			logger.Errorf("Ws[%s] conversation %s not found", h.ws.Uuid, conversationID)
-			h.sendErrorMessage(conversationID, "conversation not found")
-			return
+	conv, err := h.getOrCreateConversation(msg)
+	if err != nil {
+		h.sendError(msg.Id, err.Error())
+		return
+	}
+	conv.Question = msg.Data
+	conv.NewDialogue = true
+
+	go h.runChat(conv)
+}
+
+func (h *chat) getOrCreateConversation(msg *Message) (*AIConversation, error) {
+	if msg.Id != "" {
+		if v, ok := h.conversations.Load(msg.Id); ok {
+			return v.(*AIConversation), nil
 		}
-		conversation = c.(*AIConversation)
+		return nil, fmt.Errorf("conversation %s not found", msg.Id)
 	}

-	if msg.Interrupt {
-		conversation.InterruptCurrentChat = true
-		return
+	jmsSrv, err := proxy.NewChatJMSServer(
+		h.ws.user.String(), h.ws.ClientIP(),
+		h.ws.user.ID, h.ws.langCode, h.ws.apiClient, h.term,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("create JMS server: %w", err)
 	}

-	openAIParam := &OpenAIParam{
-		AuthToken: h.termConf.GptApiKey,
-		BaseURL:   h.termConf.GptBaseUrl,
-		Proxy:     h.termConf.GptProxy,
-		Model:     h.termConf.GptModel,
-		Prompt:    conversation.Prompt,
+	sess := session.NewSession(jmsSrv.Session, h.sessionCallback)
+	session.AddSession(sess)
+
+	conv := &AIConversation{
+		Id:        jmsSrv.Session.ID,
+		Prompt:    msg.Prompt,
+		Context:   make([]QARecord, 0),
+		JMSServer: jmsSrv,
 	}
-	conversation.HistoryRecords = append(conversation.HistoryRecords, msg.Data)
-	go h.chat(openAIParam, conversation)
-}
-
-func (h *chat) chat(
-	chatGPTParam *OpenAIParam, conversation *AIConversation,
-) string {
-	doneCh := make(chan string)
-	answerCh := make(chan string)
-	defer close(doneCh)
-	defer close(answerCh)
-
-	c := srvconn.NewOpenAIClient(
-		chatGPTParam.AuthToken,
-		chatGPTParam.BaseURL,
-		chatGPTParam.Proxy,
+	h.conversations.Store(jmsSrv.Session.ID, conv)
+	go h.Monitor(conv)
+	return conv, nil
+}
+
+func (h *chat) sessionCallback(task *model.TerminalTask) error {
+	if task.Name == model.TaskKillSession {
+		h.endConversation(task.Args, "close", "kill session")
+		return nil
+	}
+	return fmt.Errorf("unknown session task %s", task.Name)
+}
+
+func (h *chat) runChat(conv *AIConversation) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+
+	client := srvconn.NewOpenAIClient(
+		h.term.GptApiKey, h.term.GptBaseUrl, h.term.GptProxy,
 	)

-	startIndex := len(conversation.HistoryRecords) - 15
-	if startIndex < 0 {
-		startIndex = 0
+	// Keep the last 8 contexts
+	if len(conv.Context) > 8 {
+		conv.Context = conv.Context[len(conv.Context)-8:]
 	}
-	contents := conversation.HistoryRecords[startIndex:]
-
-	openAIConn := &srvconn.OpenAIConn{
-		Id:          conversation.Id,
-		Client:      c,
-		Prompt:      chatGPTParam.Prompt,
-		Model:       chatGPTParam.Model,
-		Contents:    contents,
+	messages := buildChatMessages(conv)
+
+	conn := &srvconn.OpenAIConn{
+		Id:          conv.Id,
+		Client:      client,
+		Prompt:      conv.Prompt,
+		Model:       h.term.GptModel,
+		Question:    conv.Question,
+		Context:     messages,
+		AnswerCh:    make(chan string),
+		DoneCh:      make(chan string),
 		IsReasoning: false,
-		AnswerCh:    answerCh,
-		DoneCh:      doneCh,
-		Type:        h.termConf.ChatAIType,
+		Type:        h.term.ChatAIType,
 	}

-	go openAIConn.Chat(&conversation.InterruptCurrentChat)
-	return h.processChatMessages(openAIConn)
+	// Start streaming
+	go conn.Chat(&conv.InterruptCurrentChat)
+
+	conv.JMSServer.Replay.WriteInput(conv.Question)
+
+	h.streamResponses(ctx, conv, conn)
+}
+
+func buildChatMessages(conv *AIConversation) []openai.ChatCompletionMessage {
+	msgs := make([]openai.ChatCompletionMessage, 0, len(conv.Context)*2)
+	for _, r := range conv.Context {
+		msgs = append(msgs,
+			openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: r.Question},
+			openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant, Content: r.Answer},
+		)
+	}
+	return msgs
 }

-func (h *chat) processChatMessages(
-	openAIConn *srvconn.OpenAIConn,
-) string {
-	messageID := common.UUID()
-	id := openAIConn.Id
+func (h *chat) streamResponses(
+	ctx context.Context, conv *AIConversation, conn *srvconn.OpenAIConn,
+) {
+	msgID := common.UUID()
 	for {
 		select {
-		case answer := <-openAIConn.AnswerCh:
-			h.sendSessionMessage(id, answer, messageID, "message", openAIConn.IsReasoning)
-		case answer := <-openAIConn.DoneCh:
-			h.sendSessionMessage(id, answer, messageID, "finish", false)
-			return answer
+		case <-ctx.Done():
+			h.sendError(conv.Id, "chat timeout")
+			return
+		case ans := <-conn.AnswerCh:
+			h.sendMessage(conv.Id, msgID, ans, "message", conn.IsReasoning)
+		case ans := <-conn.DoneCh:
+			h.sendMessage(conv.Id, msgID, ans, "finish", false)
+			h.finalizeConversation(conv, ans)
+			return
 		}
 	}
 }

-func (h *chat) sendSessionMessage(id, answer, messageID, messageType string, isReasoning bool) {
-	message := ChatGPTMessage{
-		Content:    answer,
-		ID:         messageID,
+func (h *chat) finalizeConversation(conv *AIConversation, fullAnswer string) {
+	runes := []rune(fullAnswer)
+	snippet := fullAnswer
+	if len(runes) > 100 {
+		snippet = string(runes[:100])
+	}
+	conv.Context = append(conv.Context, QARecord{Question: conv.Question, Answer: snippet})
+
+	cmd := conv.JMSServer.GenerateCommandItem(h.ws.user.String(), conv.Question, fullAnswer)
+	go conv.JMSServer.CmdR.Record(cmd)
+	go conv.JMSServer.Replay.WriteOutput(fullAnswer)
+}
+
+func (h *chat) sendMessage(
+	convID, msgID, content, typ string, reasoning bool,
+) {
+	msg := ChatGPTMessage{
+		Content:     content,
+		ID:          msgID,
 		CreateTime:  time.Now(),
-		Type:        messageType,
+		Type:        typ,
 		Role:        openai.ChatMessageRoleAssistant,
-		IsReasoning: isReasoning,
+		IsReasoning: reasoning,
 	}
-	data, _ := json.Marshal(message)
-	msg := Message{
-		Id:   id,
-		Type: "message",
-		Data: string(data),
+	data, _ := json.Marshal(msg)
+	h.ws.SendMessage(&Message{Id: convID, Type: "message", Data: string(data)})
+}
+
+func (h *chat) sendError(convID, errMsg string) {
+	h.endConversation(convID, "error", errMsg)
+}
+
+func (h *chat) endConversation(convID, typ, msg string) {
+	defer func() {
+		if r := recover(); r != nil {
+			logger.Errorf("panic while sending message to session %s: %v", convID, r)
+		}
+	}()
+
+	if v, ok := h.conversations.Load(convID); ok {
+		if conv, ok2 := v.(*AIConversation); ok2 && conv.JMSServer != nil {
+			conv.JMSServer.Close(msg)
+		}
 	}
-	h.ws.SendMessage(&msg)
+	h.conversations.Delete(convID)
+	h.ws.SendMessage(&Message{Id: convID, Type: typ, Data: msg})
 }

-func (h *chat) sendErrorMessage(id, message string) {
-	msg := Message{
-		Id:   id,
-		Type: "error",
-		Data: message,
+func (h *chat) interrupt(convID string) {
+	if v, ok := h.conversations.Load(convID); ok {
+		v.(*AIConversation).InterruptCurrentChat = true
 	}
-	h.ws.SendMessage(&msg)
 }

-func (h *chat) CleanConversationMap() {
-	h.conversationMap.Range(func(key, value interface{}) bool {
-		h.conversationMap.Delete(key)
+func (h *chat) cleanupAll() {
+	h.conversations.Range(func(key, _ interface{}) bool {
+		h.endConversation(key.(string), "close", "")
 		return true
 	})
 }
+
+func (h *chat) Monitor(conv *AIConversation) {
+	lang := i18n.NewLang(h.ws.langCode)
+
+	lastActiveTime := time.Now()
+	maxIdleTime := time.Duration(h.term.MaxIdleTime) * time.Minute
+	MaxSessionTime := time.Now().Add(time.Duration(h.term.MaxSessionTime) * time.Hour)
+
+	for {
+		now := time.Now()
+		if MaxSessionTime.Before(now) {
+			msg := lang.T("Session max time reached, disconnect")
+			logger.Infof("Session[%s] max session time reached, disconnect", conv.Id)
+			h.endConversation(conv.Id, "close", msg)
+			return
+		}
+
+		outTime := lastActiveTime.Add(maxIdleTime)
+		if now.After(outTime) {
+			msg := fmt.Sprintf(lang.T("Connect idle more than %d minutes, disconnect"), h.term.MaxIdleTime)
+			logger.Infof("Session[%s] idle more than %d minutes, disconnect", conv.Id, h.term.MaxIdleTime)
+			h.endConversation(conv.Id, "close", msg)
+			return
+		}
+
+		if conv.NewDialogue {
+			lastActiveTime = time.Now()
+			conv.NewDialogue = false
+		}
+
+		time.Sleep(10 * time.Second)
+	}
+}
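
The main behavioral change above is how conversation context is kept: instead of replaying the last 15 raw history strings, each conversation now stores question/answer pairs, trims them to the most recent 8, and flattens them into alternating user/assistant messages for the model. A minimal, self-contained sketch of that trimming and flattening, using the same go-openai types (the QARecord type here is a local stand-in for the struct added in message.go, and the sample questions are invented):

package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

// QARecord mirrors the struct added in pkg/httpd/message.go.
type QARecord struct {
	Question string
	Answer   string
}

// buildContext keeps only the most recent `limit` QA pairs and flattens them into
// alternating user/assistant chat messages, as runChat and buildChatMessages do above.
func buildContext(history []QARecord, limit int) []openai.ChatCompletionMessage {
	if len(history) > limit {
		history = history[len(history)-limit:]
	}
	msgs := make([]openai.ChatCompletionMessage, 0, len(history)*2)
	for _, r := range history {
		msgs = append(msgs,
			openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: r.Question},
			openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant, Content: r.Answer},
		)
	}
	return msgs
}

func main() {
	history := []QARecord{
		{Question: "how do I list files?", Answer: "Use ls -l."},
		{Question: "and disk usage?", Answer: "Use df -h."},
	}
	for _, m := range buildContext(history, 8) {
		fmt.Printf("%-9s %s\n", m.Role, m.Content)
	}
}

Note that the current question itself is not part of this slice; runChat passes it separately through the Question field on srvconn.OpenAIConn.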

pkg/httpd/message.go

+10 -1

@@ -1,6 +1,7 @@
 package httpd

 import (
+	"github.com/jumpserver/koko/pkg/proxy"
 	"time"

 	"github.com/jumpserver/koko/pkg/exchange"
@@ -163,11 +164,19 @@ type OpenAIParam struct {
 	Type string
 }

+type QARecord struct {
+	Question string
+	Answer   string
+}
+
 type AIConversation struct {
 	Id                   string
 	Prompt               string
-	HistoryRecords       []string
+	Question             string
+	Context              []QARecord
+	JMSServer            *proxy.ChatJMSServer
 	InterruptCurrentChat bool
+	NewDialogue          bool
 }

 type ChatGPTMessage struct {
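
AIConversation now carries the running Context of QARecord pairs instead of the old HistoryRecords slice. Worth noting when reading chat.go above: the Answer stored in each QARecord is capped at 100 runes by finalizeConversation, so the context sent back to the model stays short while the full answer still goes to the command record and replay. A small stand-alone illustration of that rune-safe truncation (standard library only; the sample text is invented):

package main

import "fmt"

// truncateRunes trims s to at most max runes, the same rune-safe cap that
// finalizeConversation applies before storing an answer in a QARecord.
func truncateRunes(s string, max int) string {
	runes := []rune(s)
	if len(runes) <= max {
		return s
	}
	return string(runes[:max])
}

func main() {
	answer := "第一步：运行 ls -l 查看目录内容，然后……"
	// Cutting by runes keeps multi-byte UTF-8 characters intact,
	// which a byte slice answer[:n] would not guarantee.
	fmt.Println(truncateRunes(answer, 10))
}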

pkg/httpd/webserver.go

+3 -3

@@ -158,9 +158,9 @@ func (s *Server) ChatAIWebsocket(ctx *gin.Context) {
 	}

 	userConn.handler = &chat{
-		ws:              userConn,
-		conversationMap: sync.Map{},
-		termConf:        &termConf,
+		ws:            userConn,
+		conversations: sync.Map{},
+		term:          &termConf,
 	}
 	s.broadCaster.EnterUserWebsocket(userConn)
 	defer s.broadCaster.LeaveUserWebsocket(userConn)
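
The webserver.go change is only the field renames for the handler wired into the ChatAI websocket (conversations and term). For orientation, chat still has to satisfy the package's Handler interface, whose definition is not part of this diff; judging from the compile-time assertion and the methods implemented in chat.go, it presumably looks roughly like the sketch below (both types here are hypothetical placeholders, not the real definitions):

package main

// Message is a placeholder for the websocket message type in pkg/httpd;
// the real struct carries at least Id, Type, Data, Prompt and Interrupt.
type Message struct {
	Id        string
	Type      string
	Data      string
	Prompt    string
	Interrupt bool
}

// Handler approximates the interface the chat handler satisfies; the actual
// definition lives elsewhere in pkg/httpd and may differ.
type Handler interface {
	Name() string
	CheckValidation() error
	HandleMessage(msg *Message)
	CleanUp()
}

func main() {}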

0 commit comments
