Skip to content

Commit 553c2d2

Browse files
authored
feat: support more sub types in Ernie model (#625)
1 parent 078cdba commit 553c2d2

5 files changed

Lines changed: 118 additions & 42 deletions

File tree

ai/ernie.go

Lines changed: 97 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ import (
2626
)
2727

2828
type ErnieModelProvider struct {
29+
subType string
2930
apiKey string
3031
secretKey string
3132
}
3233

33-
func NewErnieModelProvider(apiKey string, secretKey string) (*ErnieModelProvider, error) {
34-
return &ErnieModelProvider{apiKey: apiKey, secretKey: secretKey}, nil
34+
func NewErnieModelProvider(subType string, apiKey string, secretKey string) (*ErnieModelProvider, error) {
35+
return &ErnieModelProvider{subType: subType, apiKey: apiKey, secretKey: secretKey}, nil
3536
}
3637

3738
func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
@@ -42,35 +43,111 @@ func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builde
4243
return fmt.Errorf("writer does not implement http.Flusher")
4344
}
4445

45-
request := ernie.ErnieBotRequest{
46-
Messages: []ernie.ChatCompletionMessage{
47-
{
48-
Role: "user",
49-
Content: question,
50-
},
46+
messages := []ernie.ChatCompletionMessage{
47+
{
48+
Role: "user",
49+
Content: question,
5150
},
52-
Stream: true,
5351
}
54-
stream, err := client.CreateErnieBotChatCompletionStream(ctx, request)
55-
if err != nil {
56-
return err
52+
53+
flushData := func(data string) error {
54+
if _, err := fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil {
55+
return err
56+
}
57+
flusher.Flush()
58+
builder.WriteString(data)
59+
return nil
5760
}
5861

59-
defer stream.Close()
60-
for {
61-
response, err := stream.Recv()
62-
if errors.Is(err, io.EOF) {
63-
return nil
62+
if p.subType == "ERNIE-Bot" {
63+
stream, err := client.CreateErnieBotChatCompletionStream(ctx, ernie.ErnieBotRequest{Messages: messages})
64+
if err != nil {
65+
return err
66+
}
67+
68+
defer stream.Close()
69+
for {
70+
response, err := stream.Recv()
71+
if errors.Is(err, io.EOF) {
72+
return nil
73+
}
74+
75+
if err != nil {
76+
return err
77+
}
78+
79+
err = flushData(response.Result)
80+
if err != nil {
81+
return err
82+
}
83+
}
84+
} else if p.subType == "ERNIE-Bot-turbo" {
85+
stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx, ernie.ErnieBotTurboRequest{Messages: messages})
86+
if err != nil {
87+
return err
6488
}
6589

90+
defer stream.Close()
91+
for {
92+
response, err := stream.Recv()
93+
if errors.Is(err, io.EOF) {
94+
return nil
95+
}
96+
97+
if err != nil {
98+
return err
99+
}
100+
101+
err = flushData(response.Result)
102+
if err != nil {
103+
return err
104+
}
105+
}
106+
} else if p.subType == "BLOOMZ-7B" {
107+
stream, err := client.CreateBloomz7b1ChatCompletionStream(ctx, ernie.Bloomz7b1Request{Messages: messages})
66108
if err != nil {
67109
return err
68110
}
69111

70-
if _, err = fmt.Fprintf(writer, "event: message\ndata: %s\n\n", response.Result); err != nil {
112+
defer stream.Close()
113+
for {
114+
response, err := stream.Recv()
115+
if errors.Is(err, io.EOF) {
116+
return nil
117+
}
118+
119+
if err != nil {
120+
return err
121+
}
122+
123+
err = flushData(response.Result)
124+
if err != nil {
125+
return err
126+
}
127+
}
128+
} else if p.subType == "Llama-2" {
129+
stream, err := client.CreateLlamaChatCompletionStream(ctx, ernie.LlamaChatRequest{Messages: messages})
130+
if err != nil {
71131
return err
72132
}
73-
flusher.Flush()
74-
builder.WriteString(response.Result)
133+
134+
defer stream.Close()
135+
for {
136+
response, err := stream.Recv()
137+
if errors.Is(err, io.EOF) {
138+
return nil
139+
}
140+
141+
if err != nil {
142+
return err
143+
}
144+
145+
err = flushData(response.Result)
146+
if err != nil {
147+
return err
148+
}
149+
}
75150
}
151+
152+
return nil
76153
}

ai/model.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ func GetModelProvider(typ string, subType string, clientId string, clientSecret
3030
p, err = NewOpenAiModelProvider(subType, clientSecret)
3131
} else if typ == "Hugging Face" {
3232
p, err = NewHuggingFaceModelProvider(subType, clientSecret)
33-
} else if typ == "Hugging Face" {
34-
p, err = NewErnieModelProvider(clientId, clientSecret)
33+
} else if typ == "Ernie" {
34+
p, err = NewErnieModelProvider(subType, clientId, clientSecret)
3535
}
3636

3737
if err != nil {

controllers/message.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func getModelProviderFromContext(owner string, name string) (*object.Provider, e
8787
return nil, err
8888
}
8989

90-
if store.ModelProvider != "" {
90+
if store != nil && store.ModelProvider != "" {
9191
providerName = store.ModelProvider
9292
}
9393
}

web/src/ProviderEditPage.js

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,20 @@ class ProviderEditPage extends React.Component {
125125
</Select>
126126
</Col>
127127
</Row>
128-
{
129-
this.state.provider.type === "Ernie" ? null : (
130-
<Row style={{marginTop: "20px"}} >
131-
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
132-
{i18next.t("provider:Sub type")}:
133-
</Col>
134-
<Col span={22} >
135-
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}>
136-
{
137-
Setting.getProviderSubTypeOptions(this.state.provider.type)
138-
.sort((a, b) => a.name.localeCompare(b.name))
139-
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
140-
}
141-
</Select>
142-
</Col>
143-
</Row>
144-
)
145-
}
128+
<Row style={{marginTop: "20px"}} >
129+
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
130+
{i18next.t("provider:Sub type")}:
131+
</Col>
132+
<Col span={22} >
133+
<Select virtual={false} style={{width: "100%"}} value={this.state.provider.subType} onChange={(value => {this.updateProviderField("subType", value);})}>
134+
{
135+
Setting.getProviderSubTypeOptions(this.state.provider.type)
136+
.sort((a, b) => a.name.localeCompare(b.name))
137+
.map((item, index) => <Option key={index} value={item.id}>{item.name}</Option>)
138+
}
139+
</Select>
140+
</Col>
141+
</Row>
146142
{
147143
this.state.provider.type !== "Ernie" ? null : (
148144
<Row style={{marginTop: "20px"}} >

web/src/Setting.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,10 @@ export function getProviderSubTypeOptions(type) {
696696
} else if (type === "Ernie") {
697697
return (
698698
[
699-
{id: "Default", name: "Default"},
699+
{id: "ERNIE-Bot", name: "ERNIE-Bot"},
700+
{id: "ERNIE-Bot-turbo", name: "ERNIE-Bot-turbo"},
701+
{id: "BLOOMZ-7B", name: "BLOOMZ-7B"},
702+
{id: "Llama-2", name: "Llama-2"},
700703
]
701704
);
702705
} else {

0 commit comments

Comments
 (0)