@@ -3,6 +3,7 @@ package httpd
3
3
import (
4
4
"encoding/json"
5
5
"github.com/jumpserver/koko/pkg/common"
6
+ "github.com/jumpserver/koko/pkg/proxy"
6
7
"github.com/sashabaranov/go-openai"
7
8
"sync"
8
9
"time"
@@ -40,15 +41,30 @@ func (h *chat) HandleMessage(msg *Message) {
40
41
41
42
if conversationID == "" {
42
43
id := common .UUID ()
44
+
43
45
conversation = & AIConversation {
44
46
Id : id ,
45
47
Prompt : msg .Prompt ,
46
- HistoryRecords : make ([]string , 0 ),
48
+ Context : make ([]QARecord , 0 ),
47
49
InterruptCurrentChat : false ,
48
50
}
49
51
50
- // T000 Currently a websocket connection only retains one conversation
51
- h .CleanConversationMap ()
52
+ jmsServer , err := proxy .NewChatJMSServer (
53
+ id ,
54
+ h .ws .user .String (),
55
+ h .ws .ClientIP (),
56
+ h .ws .user .ID ,
57
+ h .ws .langCode ,
58
+ h .ws .apiClient ,
59
+ h .termConf ,
60
+ )
61
+ if err != nil {
62
+ logger .Errorf ("Ws[%s] create chat jms server error: %s" , h .ws .Uuid , err )
63
+ h .sendErrorMessage (conversationID , "create chat jms server error" )
64
+ return
65
+ }
66
+
67
+ conversation .JMSServer = jmsServer
52
68
h .conversationMap .Store (id , conversation )
53
69
} else {
54
70
c , ok := h .conversationMap .Load (conversationID )
@@ -65,14 +81,15 @@ func (h *chat) HandleMessage(msg *Message) {
65
81
return
66
82
}
67
83
84
+ conversation .Question = msg .Data
85
+
68
86
openAIParam := & OpenAIParam {
69
87
AuthToken : h .termConf .GptApiKey ,
70
88
BaseURL : h .termConf .GptBaseUrl ,
71
89
Proxy : h .termConf .GptProxy ,
72
90
Model : h .termConf .GptModel ,
73
91
Prompt : conversation .Prompt ,
74
92
}
75
- conversation .HistoryRecords = append (conversation .HistoryRecords , msg .Data )
76
93
go h .chat (openAIParam , conversation )
77
94
}
78
95
@@ -90,39 +107,74 @@ func (h *chat) chat(
90
107
chatGPTParam .Proxy ,
91
108
)
92
109
93
- startIndex := len (conversation .HistoryRecords ) - 15
110
+ startIndex := len (conversation .Context ) - 8
94
111
if startIndex < 0 {
95
112
startIndex = 0
96
113
}
97
- contents := conversation .HistoryRecords [startIndex :]
114
+ conversation .Context = conversation .Context [startIndex :]
115
+ context := conversation .Context
116
+
117
+ chatContext := make ([]openai.ChatCompletionMessage , 0 )
118
+ for _ , record := range context {
119
+ chatContext = append (chatContext , openai.ChatCompletionMessage {
120
+ Role : openai .ChatMessageRoleUser ,
121
+ Content : record .Question ,
122
+ })
123
+ chatContext = append (chatContext , openai.ChatCompletionMessage {
124
+ Role : openai .ChatMessageRoleAssistant ,
125
+ Content : record .Answer ,
126
+ })
127
+ }
98
128
99
129
openAIConn := & srvconn.OpenAIConn {
100
130
Id : conversation .Id ,
101
131
Client : c ,
102
132
Prompt : chatGPTParam .Prompt ,
103
133
Model : chatGPTParam .Model ,
104
- Contents : contents ,
134
+ Question : conversation .Question ,
135
+ Context : chatContext ,
105
136
IsReasoning : false ,
106
137
AnswerCh : answerCh ,
107
138
DoneCh : doneCh ,
108
139
Type : h .termConf .ChatAIType ,
109
140
}
110
141
111
142
go openAIConn .Chat (& conversation .InterruptCurrentChat )
112
- return h .processChatMessages (openAIConn )
143
+ return h .processChatMessages (openAIConn , conversation )
113
144
}
114
145
115
146
func (h * chat ) processChatMessages (
116
- openAIConn * srvconn.OpenAIConn ,
147
+ openAIConn * srvconn.OpenAIConn , conversation * AIConversation ,
117
148
) string {
118
149
messageID := common .UUID ()
119
150
id := openAIConn .Id
151
+
152
+ go conversation .JMSServer .Replay .WriteInput (conversation .Question )
153
+
120
154
for {
121
155
select {
122
156
case answer := <- openAIConn .AnswerCh :
123
157
h .sendSessionMessage (id , answer , messageID , "message" , openAIConn .IsReasoning )
124
158
case answer := <- openAIConn .DoneCh :
125
159
h .sendSessionMessage (id , answer , messageID , "finish" , false )
160
+ runes := []rune (answer )
161
+ if len (runes ) > 100 {
162
+ answer = string (runes [:100 ])
163
+ }
164
+ conversation .Context = append (conversation .Context , QARecord {
165
+ Question : conversation .Question ,
166
+ Answer : answer ,
167
+ })
168
+
169
+ cmd := conversation .JMSServer .GenerateCommandItem (
170
+ h .ws .user .String (),
171
+ conversation .Question ,
172
+ answer ,
173
+ )
174
+
175
+ go conversation .JMSServer .CmdR .Record (cmd )
176
+ go conversation .JMSServer .Replay .WriteOutput (answer )
177
+
126
178
return answer
127
179
}
128
180
}
@@ -144,6 +196,7 @@ func (h *chat) sendSessionMessage(id, answer, messageID, messageType string, isR
144
196
Data : string (data ),
145
197
}
146
198
h .ws .SendMessage (& msg )
199
+
147
200
}
148
201
149
202
func (h * chat ) sendErrorMessage (id , message string ) {
@@ -152,12 +205,40 @@ func (h *chat) sendErrorMessage(id, message string) {
152
205
Type : "error" ,
153
206
Data : message ,
154
207
}
208
+
209
+ raw , ok := h .conversationMap .Load (id )
210
+ if ! ok {
211
+ logger .Errorf ("Ws[%s] conversation %s not found" , h .ws .Uuid , id )
212
+ return
213
+ }
214
+
215
+ // check if the type is AIConversation
216
+ conv , ok := raw .(* AIConversation )
217
+ if ! ok || conv == nil {
218
+ logger .Errorf ("Ws[%s] invalid type or nil for conversation %s" , h .ws .Uuid , id )
219
+ return
220
+ }
221
+
222
+ endConversation (conv )
155
223
h .ws .SendMessage (& msg )
156
224
}
157
225
158
226
func (h * chat ) CleanConversationMap () {
159
227
h .conversationMap .Range (func (key , value interface {}) bool {
228
+ conv , ok := value .(* AIConversation )
229
+ if ! ok || conv == nil {
230
+ logger .Errorf ("Ws[%s] invalid type or nil for conversation %v" , h .ws .Uuid , key )
231
+ } else {
232
+ endConversation (conv )
233
+ }
160
234
h .conversationMap .Delete (key )
161
235
return true
162
236
})
163
237
}
238
+
239
+ func endConversation (conv * AIConversation ) {
240
+ if conv .JMSServer != nil {
241
+ conv .JMSServer .Replay .End ()
242
+ conv .JMSServer .CmdR .End ()
243
+ }
244
+ }
0 commit comments