1515package model
1616
1717import (
18+ "context"
1819 "fmt"
1920 "io"
2021 "net/http"
2122 "strings"
2223
23- iflytek "github.com/vogo/xfspark/chat"
24+ "github.com/iflytek/spark-ai-go/sparkai/llms/spark"
25+ "github.com/iflytek/spark-ai-go/sparkai/llms/spark/client/sparkclient"
26+ "github.com/iflytek/spark-ai-go/sparkai/messages"
2427)
2528
2629type iFlytekModelProvider struct {
2730 subType string
2831 appID string
2932 apiKey string
3033 secretKey string
31- temperature string
34+ temperature float32
3235 topK int
3336}
3437
35- func NewiFlytekModelProvider (subType string , secretKey string , temperature float32 , topK int ) (* iFlytekModelProvider , error ) {
38+ func NewiFlytekModelProvider (subType string , secretKey string , apiKey string , appId string , temperature float32 , topK int ) (* iFlytekModelProvider , error ) {
3639 p := & iFlytekModelProvider {
3740 subType : subType ,
38- appID : "" ,
39- apiKey : "" ,
41+ appID : appId ,
42+ apiKey : apiKey ,
4043 secretKey : secretKey ,
41- temperature : fmt . Sprintf ( "%f" , temperature ) ,
44+ temperature : temperature ,
4245 topK : topK ,
4346 }
4447 return p , nil
@@ -101,7 +104,7 @@ func (p *iFlytekModelProvider) calculatePrice(modelResult *ModelResult) error {
101104 modelResult .Currency = "CNY"
102105
103106 switch p .subType {
104- case "spark -ultra" :
107+ case "spark4.0 -ultra" :
105108 if tokenCount <= 3000000 {
106109 price = float64 (tokenCount ) / 10000 * 0.70
107110 } else if tokenCount <= 15000000 {
@@ -182,7 +185,13 @@ func (p *iFlytekModelProvider) calculatePrice(modelResult *ModelResult) error {
182185}
183186
184187func (p * iFlytekModelProvider ) QueryText (question string , writer io.Writer , history []* RawMessage , prompt string , knowledgeMessages []* RawMessage , agentInfo * AgentInfo ) (* ModelResult , error ) {
185- client := iflytek .NewServer (p .appID , p .apiKey , p .secretKey )
188+ baseUrl , domain , err := p .getBaseUrl ()
189+ _ , client , err := spark .NewClient (spark .WithBaseURL (baseUrl ), spark .WithApiKey (p .apiKey ), spark .WithApiSecret (p .secretKey ), spark .WithAppId (p .appID ), spark .WithAPIDomain (domain ))
190+ if err != nil {
191+ return nil , err
192+ }
193+ ctx := context .Background ()
194+
186195 flusher , ok := writer .(http.Flusher )
187196 if ! ok {
188197 return nil , fmt .Errorf ("writer does not implement http.Flusher" )
@@ -199,19 +208,11 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, hist
199208 }
200209 }
201210
202- session , err := client .GetSession ("1" )
203- if err != nil {
204- return nil , fmt .Errorf ("iflytek get session error: %v" , err )
205- }
206- if session == nil {
207- return nil , fmt .Errorf ("iflytek get session error: session is nil" )
208- }
211+ chatMessages := p .getChatMessages (question , history )
209212
210- session .Req .Parameter .Chat .Temperature = p .temperature
211- session .Req .Parameter .Chat .TopK = p .topK
212- response , err := session .Send (question )
213- if err != nil {
214- return nil , fmt .Errorf ("iflytek send error: %v" , err )
213+ r := & sparkclient.ChatRequest {
214+ Domain : & domain ,
215+ Messages : chatMessages ,
215216 }
216217
217218 flushData := func (data string ) error {
@@ -222,7 +223,17 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, hist
222223 return nil
223224 }
224225
225- err = flushData (response )
226+ response := ""
227+
228+ _ , err = client .CreateChatWithCallBack (ctx , r , func (msg messages.ChatMessage ) error {
229+ content := msg .GetContent ()
230+ response += content
231+ err = flushData (content )
232+ if err != nil {
233+ return err
234+ }
235+ return nil
236+ })
226237 if err != nil {
227238 return nil , err
228239 }
@@ -239,3 +250,44 @@ func (p *iFlytekModelProvider) QueryText(question string, writer io.Writer, hist
239250
240251 return modelResult , nil
241252}
253+
254+ func (p * iFlytekModelProvider ) getChatMessages (question string , history []* RawMessage ) []messages.ChatMessage {
255+ var result []messages.ChatMessage
256+
257+ for i := len (history ) - 1 ; i >= 0 ; i -- {
258+ msg := history [i ]
259+ role := "user"
260+ if msg .Author == "AI" {
261+ role = "assistant"
262+ }
263+ result = append (result , & messages.GenericChatMessage {
264+ Role : role ,
265+ Content : msg .Text ,
266+ })
267+ }
268+
269+ result = append (result , & messages.GenericChatMessage {
270+ Role : "user" ,
271+ Content : question ,
272+ })
273+
274+ return result
275+ }
276+
277+ func (p * iFlytekModelProvider ) getBaseUrl () (string , string , error ) {
278+ if p .subType == "spark4.0-ultra" {
279+ return "wss://spark-api.xf-yun.com/v4.0/chat" , "4.0Ultra" , nil
280+ } else if p .subType == "spark-max-32k" {
281+ return "wss://spark-api.xf-yun.com/chat/max-32k" , "max-32k" , nil
282+ } else if p .subType == "spark-max" {
283+ return "wss://spark-api.xf-yun.com/v3.5/chat" , "generalv3.5" , nil
284+ } else if p .subType == "spark-pro-128k" {
285+ return "wss://spark-api.xf-yun.com/chat/pro-128k" , "pro-128k" , nil
286+ } else if p .subType == "spark-pro" {
287+ return "wss://spark-api.xf-yun.com/v3.1/chat" , "generalv3" , nil
288+ } else if p .subType == "spark-lite" {
289+ return "wss://spark-api.xf-yun.com/v1.1/chat" , "lite" , nil
290+ } else {
291+ return "" , "" , fmt .Errorf ("chat model not found" )
292+ }
293+ }
0 commit comments