@@ -26,12 +26,13 @@ import (
2626)
2727
2828type ErnieModelProvider struct {
29+ subType string
2930 apiKey string
3031 secretKey string
3132}
3233
33- func NewErnieModelProvider (apiKey string , secretKey string ) (* ErnieModelProvider , error ) {
34- return & ErnieModelProvider {apiKey : apiKey , secretKey : secretKey }, nil
34+ func NewErnieModelProvider (subType string , apiKey string , secretKey string ) (* ErnieModelProvider , error ) {
35+ return & ErnieModelProvider {subType : subType , apiKey : apiKey , secretKey : secretKey }, nil
3536}
3637
3738func (p * ErnieModelProvider ) QueryText (question string , writer io.Writer , builder * strings.Builder ) error {
@@ -42,35 +43,111 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde
4243 return fmt .Errorf ("writer does not implement http.Flusher" )
4344 }
4445
45- request := ernie.ErnieBotRequest {
46- Messages : []ernie.ChatCompletionMessage {
47- {
48- Role : "user" ,
49- Content : question ,
50- },
46+ messages := []ernie.ChatCompletionMessage {
47+ {
48+ Role : "user" ,
49+ Content : question ,
5150 },
52- Stream : true ,
5351 }
54- stream , err := client .CreateErnieBotChatCompletionStream (ctx , request )
55- if err != nil {
56- return err
52+
53+ flushData := func (data string ) error {
54+ if _ , err := fmt .Fprintf (writer , "event: message\n data: %s\n \n " , data ); err != nil {
55+ return err
56+ }
57+ flusher .Flush ()
58+ builder .WriteString (data )
59+ return nil
5760 }
5861
59- defer stream .Close ()
60- for {
61- response , err := stream .Recv ()
62- if errors .Is (err , io .EOF ) {
63- return nil
62+ if p .subType == "ERNIE-Bot" {
63+ stream , err := client .CreateErnieBotChatCompletionStream (ctx , ernie.ErnieBotRequest {Messages : messages })
64+ if err != nil {
65+ return err
66+ }
67+
68+ defer stream .Close ()
69+ for {
70+ response , err := stream .Recv ()
71+ if errors .Is (err , io .EOF ) {
72+ return nil
73+ }
74+
75+ if err != nil {
76+ return err
77+ }
78+
79+ err = flushData (response .Result )
80+ if err != nil {
81+ return err
82+ }
83+ }
84+ } else if p .subType == "ERNIE-Bot-turbo" {
85+ stream , err := client .CreateErnieBotTurboChatCompletionStream (ctx , ernie.ErnieBotTurboRequest {Messages : messages })
86+ if err != nil {
87+ return err
6488 }
6589
90+ defer stream .Close ()
91+ for {
92+ response , err := stream .Recv ()
93+ if errors .Is (err , io .EOF ) {
94+ return nil
95+ }
96+
97+ if err != nil {
98+ return err
99+ }
100+
101+ err = flushData (response .Result )
102+ if err != nil {
103+ return err
104+ }
105+ }
106+ } else if p .subType == "BLOOMZ-7B" {
107+ stream , err := client .CreateBloomz7b1ChatCompletionStream (ctx , ernie.Bloomz7b1Request {Messages : messages })
66108 if err != nil {
67109 return err
68110 }
69111
70- if _ , err = fmt .Fprintf (writer , "event: message\n data: %s\n \n " , response .Result ); err != nil {
112+ defer stream .Close ()
113+ for {
114+ response , err := stream .Recv ()
115+ if errors .Is (err , io .EOF ) {
116+ return nil
117+ }
118+
119+ if err != nil {
120+ return err
121+ }
122+
123+ err = flushData (response .Result )
124+ if err != nil {
125+ return err
126+ }
127+ }
128+ } else if p .subType == "Llama-2" {
129+ stream , err := client .CreateLlamaChatCompletionStream (ctx , ernie.LlamaChatRequest {Messages : messages })
130+ if err != nil {
71131 return err
72132 }
73- flusher .Flush ()
74- builder .WriteString (response .Result )
133+
134+ defer stream .Close ()
135+ for {
136+ response , err := stream .Recv ()
137+ if errors .Is (err , io .EOF ) {
138+ return nil
139+ }
140+
141+ if err != nil {
142+ return err
143+ }
144+
145+ err = flushData (response .Result )
146+ if err != nil {
147+ return err
148+ }
149+ }
75150 }
151+
152+ return nil
76153}
0 commit comments