|
1 | 1 | package httpd
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
4 | 5 | "encoding/json"
|
| 6 | + "fmt" |
5 | 7 | "github.com/jumpserver/koko/pkg/common"
|
| 8 | + "github.com/jumpserver/koko/pkg/i18n" |
| 9 | + "github.com/jumpserver/koko/pkg/logger" |
| 10 | + "github.com/jumpserver/koko/pkg/proxy" |
| 11 | + "github.com/jumpserver/koko/pkg/session" |
6 | 12 | "github.com/sashabaranov/go-openai"
|
7 | 13 | "sync"
|
8 | 14 | "time"
|
9 | 15 |
|
10 | 16 | "github.com/jumpserver/koko/pkg/jms-sdk-go/model"
|
11 |
| - "github.com/jumpserver/koko/pkg/logger" |
12 | 17 | "github.com/jumpserver/koko/pkg/srvconn"
|
13 | 18 | )
|
14 | 19 |
|
// Compile-time check that *chat satisfies the Handler interface.
var _ Handler = (*chat)(nil)

// chat is the websocket Handler for AI chat: it owns the conversations
// opened over one user websocket connection and relays streamed model
// answers back to the client.
type chat struct {
	ws   *UserWebsocket        // owning websocket connection
	term *model.TerminalConfig // terminal config: GPT credentials, idle/session limits

	// conversationMap: map[conversationID]*AIConversation
	conversations sync.Map
}
|
24 | 29 |
|
25 | 30 | func (h *chat) Name() string {
|
26 | 31 | return ChatName
|
27 | 32 | }
|
28 | 33 |
|
29 |
| -func (h *chat) CleanUp() { |
30 |
| - h.CleanConversationMap() |
31 |
| -} |
| 34 | +func (h *chat) CleanUp() { h.cleanupAll() } |
32 | 35 |
|
33 | 36 | func (h *chat) CheckValidation() error {
|
34 | 37 | return nil
|
35 | 38 | }
|
36 | 39 |
|
// HandleMessage dispatches one incoming websocket chat message: an
// interrupt request flags the in-flight completion for the given
// conversation to stop, anything else is treated as a new question and
// answered asynchronously.
func (h *chat) HandleMessage(msg *Message) {
	// Interrupt messages carry no question payload; just flag and return.
	if msg.Interrupt {
		h.interrupt(msg.Id)
		return
	}

	conv, err := h.getOrCreateConversation(msg)
	if err != nil {
		h.sendError(msg.Id, err.Error())
		return
	}
	// NOTE(review): Question and NewDialogue are written here and read by
	// the runChat/Monitor goroutines without synchronization — looks racy
	// if the client sends a second question before the first completes;
	// confirm the websocket read loop serializes messages per connection.
	conv.Question = msg.Data
	conv.NewDialogue = true

	// Answer asynchronously so the websocket read loop is not blocked.
	go h.runChat(conv)
}
| 56 | + |
// getOrCreateConversation resolves msg to an existing conversation by ID,
// or — when msg.Id is empty — provisions a new one backed by a JMS chat
// session (for audit/replay recording) plus a Monitor goroutine that
// enforces idle and max-session limits.
//
// Returns an error when msg.Id names a conversation that is not in the
// registry, or when the JMS session cannot be created.
func (h *chat) getOrCreateConversation(msg *Message) (*AIConversation, error) {
	if msg.Id != "" {
		if v, ok := h.conversations.Load(msg.Id); ok {
			return v.(*AIConversation), nil
		}
		return nil, fmt.Errorf("conversation %s not found", msg.Id)
	}

	// New conversation: register a chat session with the JMS server so the
	// dialogue is audited like any other koko session.
	jmsSrv, err := proxy.NewChatJMSServer(
		h.ws.user.String(), h.ws.ClientIP(),
		h.ws.user.ID, h.ws.langCode, h.ws.apiClient, h.term,
	)
	if err != nil {
		return nil, fmt.Errorf("create JMS server: %w", err)
	}

	// sessionCallback lets the platform kill this chat session remotely.
	sess := session.NewSession(jmsSrv.Session, h.sessionCallback)
	session.AddSession(sess)

	conv := &AIConversation{
		Id:        jmsSrv.Session.ID, // conversation ID doubles as the session ID
		Prompt:    msg.Prompt,
		Context:   make([]QARecord, 0),
		JMSServer: jmsSrv,
	}

	h.conversations.Store(jmsSrv.Session.ID, conv)
	// Monitor enforces idle/max-session timeouts for this conversation.
	go h.Monitor(conv)
	return conv, nil
}
| 86 | + |
| 87 | +func (h *chat) sessionCallback(task *model.TerminalTask) error { |
| 88 | + if task.Name == model.TaskKillSession { |
| 89 | + h.endConversation(task.Args, "close", "kill session") |
| 90 | + return nil |
| 91 | + } |
| 92 | + return fmt.Errorf("unknown session task %s", task.Name) |
| 93 | +} |
| 94 | + |
// runChat streams one answer for conv's current Question: it builds the
// OpenAI client and request from the terminal config, records the question
// in the session replay, and forwards streamed chunks to the websocket.
// The whole exchange is bounded by a 2-minute timeout.
func (h *chat) runChat(conv *AIConversation) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client := srvconn.NewOpenAIClient(
		h.term.GptApiKey, h.term.GptBaseUrl, h.term.GptProxy,
	)

	// Keep the last 8 contexts
	if len(conv.Context) > 8 {
		conv.Context = conv.Context[len(conv.Context)-8:]
	}

	messages := buildChatMessages(conv)

	conn := &srvconn.OpenAIConn{
		Id:          conv.Id,
		Client:      client,
		Prompt:      conv.Prompt,
		Model:       h.term.GptModel,
		Question:    conv.Question,
		Context:     messages,
		AnswerCh:    make(chan string), // unbuffered: chunks are consumed by streamResponses
		DoneCh:      make(chan string),
		IsReasoning: false,
		Type:        h.term.ChatAIType,
	}

	// Start streaming; the flag pointer lets interrupt() stop it mid-answer.
	go conn.Chat(&conv.InterruptCurrentChat)

	// Record the user's question in the session replay for auditing.
	conv.JMSServer.Replay.WriteInput(conv.Question)

	h.streamResponses(ctx, conv, conn)
}
| 129 | + |
| 130 | +func buildChatMessages(conv *AIConversation) []openai.ChatCompletionMessage { |
| 131 | + msgs := make([]openai.ChatCompletionMessage, 0, len(conv.Context)*2) |
| 132 | + for _, r := range conv.Context { |
| 133 | + msgs = append(msgs, |
| 134 | + openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: r.Question}, |
| 135 | + openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant, Content: r.Answer}, |
| 136 | + ) |
| 137 | + } |
| 138 | + return msgs |
113 | 139 | }
|
114 | 140 |
|
115 |
| -func (h *chat) processChatMessages( |
116 |
| - openAIConn *srvconn.OpenAIConn, |
117 |
| -) string { |
118 |
| - messageID := common.UUID() |
119 |
| - id := openAIConn.Id |
| 141 | +func (h *chat) streamResponses( |
| 142 | + ctx context.Context, conv *AIConversation, conn *srvconn.OpenAIConn, |
| 143 | +) { |
| 144 | + msgID := common.UUID() |
120 | 145 | for {
|
121 | 146 | select {
|
122 |
| - case answer := <-openAIConn.AnswerCh: |
123 |
| - h.sendSessionMessage(id, answer, messageID, "message", openAIConn.IsReasoning) |
124 |
| - case answer := <-openAIConn.DoneCh: |
125 |
| - h.sendSessionMessage(id, answer, messageID, "finish", false) |
126 |
| - return answer |
| 147 | + case <-ctx.Done(): |
| 148 | + h.sendError(conv.Id, "chat timeout") |
| 149 | + return |
| 150 | + case ans := <-conn.AnswerCh: |
| 151 | + h.sendMessage(conv.Id, msgID, ans, "message", conn.IsReasoning) |
| 152 | + case ans := <-conn.DoneCh: |
| 153 | + h.sendMessage(conv.Id, msgID, ans, "finish", false) |
| 154 | + h.finalizeConversation(conv, ans) |
| 155 | + return |
127 | 156 | }
|
128 | 157 | }
|
129 | 158 | }
|
130 | 159 |
|
131 |
| -func (h *chat) sendSessionMessage(id, answer, messageID, messageType string, isReasoning bool) { |
132 |
| - message := ChatGPTMessage{ |
133 |
| - Content: answer, |
134 |
| - ID: messageID, |
| 160 | +func (h *chat) finalizeConversation(conv *AIConversation, fullAnswer string) { |
| 161 | + runes := []rune(fullAnswer) |
| 162 | + snippet := fullAnswer |
| 163 | + if len(runes) > 100 { |
| 164 | + snippet = string(runes[:100]) |
| 165 | + } |
| 166 | + conv.Context = append(conv.Context, QARecord{Question: conv.Question, Answer: snippet}) |
| 167 | + |
| 168 | + cmd := conv.JMSServer.GenerateCommandItem(h.ws.user.String(), conv.Question, fullAnswer) |
| 169 | + go conv.JMSServer.CmdR.Record(cmd) |
| 170 | + go conv.JMSServer.Replay.WriteOutput(fullAnswer) |
| 171 | +} |
| 172 | + |
| 173 | +func (h *chat) sendMessage( |
| 174 | + convID, msgID, content, typ string, reasoning bool, |
| 175 | +) { |
| 176 | + msg := ChatGPTMessage{ |
| 177 | + Content: content, |
| 178 | + ID: msgID, |
135 | 179 | CreateTime: time.Now(),
|
136 |
| - Type: messageType, |
| 180 | + Type: typ, |
137 | 181 | Role: openai.ChatMessageRoleAssistant,
|
138 |
| - IsReasoning: isReasoning, |
| 182 | + IsReasoning: reasoning, |
139 | 183 | }
|
140 |
| - data, _ := json.Marshal(message) |
141 |
| - msg := Message{ |
142 |
| - Id: id, |
143 |
| - Type: "message", |
144 |
| - Data: string(data), |
| 184 | + data, _ := json.Marshal(msg) |
| 185 | + h.ws.SendMessage(&Message{Id: convID, Type: "message", Data: string(data)}) |
| 186 | +} |
| 187 | + |
| 188 | +func (h *chat) sendError(convID, errMsg string) { |
| 189 | + h.endConversation(convID, "error", errMsg) |
| 190 | +} |
| 191 | + |
| 192 | +func (h *chat) endConversation(convID, typ, msg string) { |
| 193 | + |
| 194 | + defer func() { |
| 195 | + if r := recover(); r != nil { |
| 196 | + logger.Errorf("panic while sending message to session %s: %v", convID, r) |
| 197 | + } |
| 198 | + }() |
| 199 | + |
| 200 | + if v, ok := h.conversations.Load(convID); ok { |
| 201 | + if conv, ok2 := v.(*AIConversation); ok2 && conv.JMSServer != nil { |
| 202 | + conv.JMSServer.Close(msg) |
| 203 | + } |
145 | 204 | }
|
146 |
| - h.ws.SendMessage(&msg) |
| 205 | + h.conversations.Delete(convID) |
| 206 | + h.ws.SendMessage(&Message{Id: convID, Type: typ, Data: msg}) |
147 | 207 | }
|
148 | 208 |
|
149 |
| -func (h *chat) sendErrorMessage(id, message string) { |
150 |
| - msg := Message{ |
151 |
| - Id: id, |
152 |
| - Type: "error", |
153 |
| - Data: message, |
| 209 | +func (h *chat) interrupt(convID string) { |
| 210 | + if v, ok := h.conversations.Load(convID); ok { |
| 211 | + v.(*AIConversation).InterruptCurrentChat = true |
154 | 212 | }
|
155 |
| - h.ws.SendMessage(&msg) |
156 | 213 | }
|
157 | 214 |
|
158 |
| -func (h *chat) CleanConversationMap() { |
159 |
| - h.conversationMap.Range(func(key, value interface{}) bool { |
160 |
| - h.conversationMap.Delete(key) |
| 215 | +func (h *chat) cleanupAll() { |
| 216 | + h.conversations.Range(func(key, _ interface{}) bool { |
| 217 | + h.endConversation(key.(string), "close", "") |
161 | 218 | return true
|
162 | 219 | })
|
163 | 220 | }
|
| 221 | + |
| 222 | +func (h *chat) Monitor(conv *AIConversation) { |
| 223 | + lang := i18n.NewLang(h.ws.langCode) |
| 224 | + |
| 225 | + lastActiveTime := time.Now() |
| 226 | + maxIdleTime := time.Duration(h.term.MaxIdleTime) * time.Minute |
| 227 | + MaxSessionTime := time.Now().Add(time.Duration(h.term.MaxSessionTime) * time.Hour) |
| 228 | + |
| 229 | + for { |
| 230 | + now := time.Now() |
| 231 | + if MaxSessionTime.Before(now) { |
| 232 | + msg := lang.T("Session max time reached, disconnect") |
| 233 | + logger.Infof("Session[%s] max session time reached, disconnect", conv.Id) |
| 234 | + h.endConversation(conv.Id, "close", msg) |
| 235 | + return |
| 236 | + } |
| 237 | + |
| 238 | + outTime := lastActiveTime.Add(maxIdleTime) |
| 239 | + if now.After(outTime) { |
| 240 | + msg := fmt.Sprintf(lang.T("Connect idle more than %d minutes, disconnect"), h.term.MaxIdleTime) |
| 241 | + logger.Infof("Session[%s] idle more than %d minutes, disconnect", conv.Id, h.term.MaxIdleTime) |
| 242 | + h.endConversation(conv.Id, "close", msg) |
| 243 | + return |
| 244 | + } |
| 245 | + |
| 246 | + if conv.NewDialogue { |
| 247 | + lastActiveTime = time.Now() |
| 248 | + conv.NewDialogue = false |
| 249 | + } |
| 250 | + |
| 251 | + time.Sleep(10 * time.Second) |
| 252 | + } |
| 253 | +} |
0 commit comments