Skip to content

Commit 80c20c2

Browse files
committed
perf: Resurrection Kael
1 parent 473d750 commit 80c20c2

File tree

8 files changed

+266
-16
lines changed

8 files changed

+266
-16
lines changed

pkg/httpd/chat.go

+90-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package httpd
33
import (
44
"encoding/json"
55
"github.com/jumpserver/koko/pkg/common"
6+
"github.com/jumpserver/koko/pkg/proxy"
67
"github.com/sashabaranov/go-openai"
78
"sync"
89
"time"
@@ -40,15 +41,30 @@ func (h *chat) HandleMessage(msg *Message) {
4041

4142
if conversationID == "" {
4243
id := common.UUID()
44+
4345
conversation = &AIConversation{
4446
Id: id,
4547
Prompt: msg.Prompt,
46-
HistoryRecords: make([]string, 0),
48+
Context: make([]QARecord, 0),
4749
InterruptCurrentChat: false,
4850
}
4951

50-
// T000 Currently a websocket connection only retains one conversation
51-
h.CleanConversationMap()
52+
jmsServer, err := proxy.NewChatJMSServer(
53+
id,
54+
h.ws.user.String(),
55+
h.ws.ClientIP(),
56+
h.ws.user.ID,
57+
h.ws.langCode,
58+
h.ws.apiClient,
59+
h.termConf,
60+
)
61+
if err != nil {
62+
logger.Errorf("Ws[%s] create chat jms server error: %s", h.ws.Uuid, err)
63+
h.sendErrorMessage(conversationID, "create chat jms server error")
64+
return
65+
}
66+
67+
conversation.JMSServer = jmsServer
5268
h.conversationMap.Store(id, conversation)
5369
} else {
5470
c, ok := h.conversationMap.Load(conversationID)
@@ -65,14 +81,15 @@ func (h *chat) HandleMessage(msg *Message) {
6581
return
6682
}
6783

84+
conversation.Question = msg.Data
85+
6886
openAIParam := &OpenAIParam{
6987
AuthToken: h.termConf.GptApiKey,
7088
BaseURL: h.termConf.GptBaseUrl,
7189
Proxy: h.termConf.GptProxy,
7290
Model: h.termConf.GptModel,
7391
Prompt: conversation.Prompt,
7492
}
75-
conversation.HistoryRecords = append(conversation.HistoryRecords, msg.Data)
7693
go h.chat(openAIParam, conversation)
7794
}
7895

@@ -90,39 +107,74 @@ func (h *chat) chat(
90107
chatGPTParam.Proxy,
91108
)
92109

93-
startIndex := len(conversation.HistoryRecords) - 15
110+
startIndex := len(conversation.Context) - 8
94111
if startIndex < 0 {
95112
startIndex = 0
96113
}
97-
contents := conversation.HistoryRecords[startIndex:]
114+
conversation.Context = conversation.Context[startIndex:]
115+
context := conversation.Context
116+
117+
chatContext := make([]openai.ChatCompletionMessage, 0)
118+
for _, record := range context {
119+
chatContext = append(chatContext, openai.ChatCompletionMessage{
120+
Role: openai.ChatMessageRoleUser,
121+
Content: record.Question,
122+
})
123+
chatContext = append(chatContext, openai.ChatCompletionMessage{
124+
Role: openai.ChatMessageRoleAssistant,
125+
Content: record.Answer,
126+
})
127+
}
98128

99129
openAIConn := &srvconn.OpenAIConn{
100130
Id: conversation.Id,
101131
Client: c,
102132
Prompt: chatGPTParam.Prompt,
103133
Model: chatGPTParam.Model,
104-
Contents: contents,
134+
Question: conversation.Question,
135+
Context: chatContext,
105136
IsReasoning: false,
106137
AnswerCh: answerCh,
107138
DoneCh: doneCh,
108139
Type: h.termConf.ChatAIType,
109140
}
110141

111142
go openAIConn.Chat(&conversation.InterruptCurrentChat)
112-
return h.processChatMessages(openAIConn)
143+
return h.processChatMessages(openAIConn, conversation)
113144
}
114145

115146
func (h *chat) processChatMessages(
116-
openAIConn *srvconn.OpenAIConn,
147+
openAIConn *srvconn.OpenAIConn, conversation *AIConversation,
117148
) string {
118149
messageID := common.UUID()
119150
id := openAIConn.Id
151+
152+
go conversation.JMSServer.Replay.WriteInput(conversation.Question)
153+
120154
for {
121155
select {
122156
case answer := <-openAIConn.AnswerCh:
123157
h.sendSessionMessage(id, answer, messageID, "message", openAIConn.IsReasoning)
124158
case answer := <-openAIConn.DoneCh:
125159
h.sendSessionMessage(id, answer, messageID, "finish", false)
160+
runes := []rune(answer)
161+
if len(runes) > 100 {
162+
answer = string(runes[:100])
163+
}
164+
conversation.Context = append(conversation.Context, QARecord{
165+
Question: conversation.Question,
166+
Answer: answer,
167+
})
168+
169+
cmd := conversation.JMSServer.GenerateCommandItem(
170+
h.ws.user.String(),
171+
conversation.Question,
172+
answer,
173+
)
174+
175+
go conversation.JMSServer.CmdR.Record(cmd)
176+
go conversation.JMSServer.Replay.WriteOutput(answer)
177+
126178
return answer
127179
}
128180
}
@@ -144,6 +196,7 @@ func (h *chat) sendSessionMessage(id, answer, messageID, messageType string, isR
144196
Data: string(data),
145197
}
146198
h.ws.SendMessage(&msg)
199+
147200
}
148201

149202
func (h *chat) sendErrorMessage(id, message string) {
@@ -152,12 +205,40 @@ func (h *chat) sendErrorMessage(id, message string) {
152205
Type: "error",
153206
Data: message,
154207
}
208+
209+
raw, ok := h.conversationMap.Load(id)
210+
if !ok {
211+
logger.Errorf("Ws[%s] conversation %s not found", h.ws.Uuid, id)
212+
return
213+
}
214+
215+
// check if the type is AIConversation
216+
conv, ok := raw.(*AIConversation)
217+
if !ok || conv == nil {
218+
logger.Errorf("Ws[%s] invalid type or nil for conversation %s", h.ws.Uuid, id)
219+
return
220+
}
221+
222+
endConversation(conv)
155223
h.ws.SendMessage(&msg)
156224
}
157225

158226
func (h *chat) CleanConversationMap() {
159227
h.conversationMap.Range(func(key, value interface{}) bool {
228+
conv, ok := value.(*AIConversation)
229+
if !ok || conv == nil {
230+
logger.Errorf("Ws[%s] invalid type or nil for conversation %v", h.ws.Uuid, key)
231+
} else {
232+
endConversation(conv)
233+
}
160234
h.conversationMap.Delete(key)
161235
return true
162236
})
163237
}
238+
239+
func endConversation(conv *AIConversation) {
240+
if conv.JMSServer != nil {
241+
conv.JMSServer.Replay.End()
242+
conv.JMSServer.CmdR.End()
243+
}
244+
}

pkg/httpd/message.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package httpd
22

33
import (
4+
"github.com/jumpserver/koko/pkg/proxy"
45
"time"
56

67
"github.com/jumpserver/koko/pkg/exchange"
@@ -163,10 +164,17 @@ type OpenAIParam struct {
163164
Type string
164165
}
165166

167+
type QARecord struct {
168+
Question string
169+
Answer string
170+
}
171+
166172
type AIConversation struct {
167173
Id string
168174
Prompt string
169-
HistoryRecords []string
175+
Question string
176+
Context []QARecord
177+
JMSServer *proxy.ChatJMSServer
170178
InterruptCurrentChat bool
171179
}
172180

pkg/jms-sdk-go/model/account.go

+12
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ type AccountDetail struct {
5555
Privileged bool `json:"privileged"`
5656
}
5757

58+
type AssetChat struct {
59+
ID string `json:"id"`
60+
Name string `json:"name"`
61+
}
62+
63+
type AccountChatDetail struct {
64+
ID string `json:"id"`
65+
Name string `json:"name"`
66+
Username string `json:"username"`
67+
Asset AssetChat `json:"asset"`
68+
}
69+
5870
type PermAccount struct {
5971
Name string `json:"name"`
6072
Username string `json:"username"`

pkg/jms-sdk-go/service/jms_asset.go

+6
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,9 @@ func (s *JMService) GetAccountSecretById(accountId string) (res model.AccountDet
2323
_, err = s.authClient.Get(url, &res)
2424
return
2525
}
26+
27+
func (s *JMService) GetAccountChat() (res model.AccountChatDetail, err error) {
28+
url := fmt.Sprintf(AccountChatURL)
29+
_, err = s.authClient.Get(url, &res)
30+
return
31+
}

pkg/jms-sdk-go/service/url.go

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ const (
7777

7878
UserPermsAssetAccountsURL = "/api/v1/perms/users/%s/assets/%s/"
7979
AccountSecretURL = "/api/v1/assets/account-secrets/%s/"
80+
AccountChatURL = "/api/v1/accounts/accounts/chat/"
8081
UserPermsAssetsURL = "/api/v1/perms/users/%s/assets/"
8182

8283
AssetLoginConfirmURL = "/api/v1/acls/login-asset/check/"

pkg/proxy/chat.go

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package proxy
2+
3+
import (
4+
"fmt"
5+
modelCommon "github.com/jumpserver/koko/pkg/jms-sdk-go/common"
6+
"github.com/jumpserver/koko/pkg/jms-sdk-go/model"
7+
"github.com/jumpserver/koko/pkg/jms-sdk-go/service"
8+
"github.com/jumpserver/koko/pkg/logger"
9+
"strings"
10+
"time"
11+
)
12+
13+
func NewChatJMSServer(id, user, ip, userID, langCode string, jmsService *service.JMService, conf *model.TerminalConfig) (*ChatJMSServer, error) {
14+
accountInfo, err := jmsService.GetAccountChat()
15+
if err != nil {
16+
logger.Errorf("Get account chat info error: %s", err)
17+
return nil, err
18+
}
19+
20+
apiSession := &model.Session{
21+
ID: id,
22+
User: user,
23+
LoginFrom: model.LoginFromWeb,
24+
RemoteAddr: ip,
25+
Protocol: model.ActionALL,
26+
Asset: accountInfo.Asset.Name,
27+
Account: accountInfo.Name,
28+
AccountID: accountInfo.ID,
29+
AssetID: accountInfo.Asset.ID,
30+
UserID: userID,
31+
OrgID: "00000000-0000-0000-0000-000000000004",
32+
Type: model.NORMALType,
33+
LangCode: langCode,
34+
DateStart: modelCommon.NewNowUTCTime(),
35+
}
36+
37+
_, err2 := jmsService.CreateSession(*apiSession)
38+
if err2 != nil {
39+
return nil, err2
40+
}
41+
42+
cmdR := GetCommandRecorder(id, jmsService, conf)
43+
replay := GetReplayRecorder(id, jmsService, conf)
44+
45+
return &ChatJMSServer{
46+
Session: apiSession,
47+
CmdR: cmdR,
48+
Replay: replay,
49+
}, nil
50+
}
51+
52+
type ChatReplyRecorder struct {
53+
*ReplyRecorder
54+
}
55+
56+
func (rh *ChatReplyRecorder) WriteInput(inputStr string) {
57+
currentTime := time.Now()
58+
formattedTime := currentTime.Format("2006-01-02 15:04:05")
59+
inputStr = fmt.Sprintf("[%s]#: %s", formattedTime, inputStr)
60+
rh.Record([]byte(inputStr))
61+
}
62+
63+
func (rh *ChatReplyRecorder) WriteOutput(outputStr string) {
64+
wrappedText := rh.wrapText(outputStr)
65+
outputStr = "\r\n" + wrappedText + "\r\n"
66+
rh.Record([]byte(outputStr))
67+
68+
}
69+
70+
func (rh *ChatReplyRecorder) wrapText(text string) string {
71+
var wrappedTextBuilder strings.Builder
72+
words := strings.Fields(text)
73+
currentLineLength := 0
74+
75+
for _, word := range words {
76+
wordLength := len(word)
77+
78+
if currentLineLength+wordLength > rh.Writer.Width {
79+
wrappedTextBuilder.WriteString("\r\n" + word + " ")
80+
currentLineLength = wordLength + 1
81+
} else {
82+
wrappedTextBuilder.WriteString(word + " ")
83+
currentLineLength += wordLength + 1
84+
}
85+
}
86+
87+
return wrappedTextBuilder.String()
88+
}
89+
90+
type ChatJMSServer struct {
91+
Session *model.Session
92+
CmdR *CommandRecorder
93+
Replay *ChatReplyRecorder
94+
}
95+
96+
func (s *ChatJMSServer) GenerateCommandItem(user, input, output string) *model.Command {
97+
createdDate := time.Now()
98+
return &model.Command{
99+
SessionID: s.Session.ID,
100+
OrgID: s.Session.OrgID,
101+
Input: input,
102+
Output: output,
103+
User: user,
104+
Server: s.Session.Asset,
105+
Account: s.Session.Account,
106+
Timestamp: createdDate.Unix(),
107+
RiskLevel: model.NormalLevel,
108+
DateCreated: createdDate.UTC(),
109+
}
110+
}
111+
112+
func GetReplayRecorder(id string, jmsService *service.JMService, conf *model.TerminalConfig) *ChatReplyRecorder {
113+
info := &ReplyInfo{
114+
Width: 200,
115+
Height: 200,
116+
TimeStamp: time.Now(),
117+
}
118+
recorder, err := NewReplayRecord(id, jmsService,
119+
NewReplayStorage(jmsService, conf),
120+
info)
121+
if err != nil {
122+
logger.Error(err)
123+
}
124+
125+
return &ChatReplyRecorder{recorder}
126+
}
127+
128+
func GetCommandRecorder(id string, jmsService *service.JMService, conf *model.TerminalConfig) *CommandRecorder {
129+
cmdR := CommandRecorder{
130+
sessionID: id,
131+
storage: NewCommandStorage(jmsService, conf),
132+
queue: make(chan *model.Command, 10),
133+
closed: make(chan struct{}),
134+
jmsService: jmsService,
135+
}
136+
go cmdR.record()
137+
return &cmdR
138+
}

0 commit comments

Comments
 (0)