Skip to content

Commit 37c6b00

Browse files
authored
feat: support OpenAI model parameters (#641)
* feat: support openai completion parameters * fix: update providerEditPage * fix: update providerEditPage
1 parent bccb50b commit 37c6b00

6 files changed

Lines changed: 154 additions & 12 deletions

File tree

model/openai.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,24 @@ var __maxTokens = map[string]int{
5151
}
5252

5353
type OpenAiModelProvider struct {
54-
subType string
55-
secretKey string
54+
subType string
55+
secretKey string
56+
temperature float32
57+
topP float32
58+
frequencyPenalty float32
59+
presencePenalty float32
5660
}
5761

58-
func NewOpenAiModelProvider(subType string, secretKey string) (*OpenAiModelProvider, error) {
59-
return &OpenAiModelProvider{subType: subType, secretKey: secretKey}, nil
62+
func NewOpenAiModelProvider(subType string, secretKey string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32) (*OpenAiModelProvider, error) {
63+
p := &OpenAiModelProvider{
64+
subType: subType,
65+
secretKey: secretKey,
66+
temperature: temperature,
67+
topP: topP,
68+
frequencyPenalty: frequencyPenalty,
69+
presencePenalty: presencePenalty,
70+
}
71+
return p, nil
6072
}
6173

6274
func getProxyClientFromToken(authToken string) *openai.Client {
@@ -98,14 +110,22 @@ func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, build
98110
}
99111

100112
maxTokens := p.GetMaxTokens() - promptTokens
113+
temperature := p.temperature
114+
topP := p.topP
115+
frequencyPenalty := p.frequencyPenalty
116+
presencePenalty := p.presencePenalty
101117

102118
respStream, err := client.CreateCompletionStream(
103119
ctx,
104120
openai.CompletionRequest{
105-
Model: model,
106-
Prompt: question,
107-
MaxTokens: maxTokens,
108-
Stream: true,
121+
Model: model,
122+
Prompt: question,
123+
MaxTokens: maxTokens,
124+
Stream: true,
125+
Temperature: temperature,
126+
TopP: topP,
127+
FrequencyPenalty: frequencyPenalty,
128+
PresencePenalty: presencePenalty,
109129
},
110130
)
111131
if err != nil {

model/provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ type ModelProvider interface {
2323
QueryText(question string, writer io.Writer, builder *strings.Builder) error
2424
}
2525

26-
func GetModelProvider(typ string, subType string, clientId string, clientSecret string) (ModelProvider, error) {
26+
func GetModelProvider(typ string, subType string, clientId string, clientSecret string, temperature float32, topP float32, frequencyPenalty float32, presencePenalty float32) (ModelProvider, error) {
2727
var p ModelProvider
2828
var err error
2929
if typ == "OpenAI" {
30-
p, err = NewOpenAiModelProvider(subType, clientSecret)
30+
p, err = NewOpenAiModelProvider(subType, clientSecret, temperature, topP, frequencyPenalty, presencePenalty)
3131
} else if typ == "Hugging Face" {
3232
p, err = NewHuggingFaceModelProvider(subType, clientSecret)
3333
} else if typ == "Claude" {

object/provider.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ type Provider struct {
3636
ClientId string `xorm:"varchar(100)" json:"clientId"`
3737
ClientSecret string `xorm:"varchar(2000)" json:"clientSecret"`
3838
ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"`
39+
40+
Temperature float32 `xorm:"float" json:"temperature"`
41+
TopP float32 `xorm:"float" json:"topP"`
42+
FrequencyPenalty float32 `xorm:"float" json:"frequencyPenalty"`
43+
PresencePenalty float32 `xorm:"float" json:"presencePenalty"`
3944
}
4045

4146
func GetMaskedProvider(provider *Provider, isMaskEnabled bool) *Provider {
@@ -205,7 +210,7 @@ func (p *Provider) GetStorageProviderObj() (storage.StorageProvider, error) {
205210
}
206211

207212
func (p *Provider) GetModelProvider() (model.ModelProvider, error) {
208-
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret)
213+
pProvider, err := model.GetModelProvider(p.Type, p.SubType, p.ClientId, p.ClientSecret, p.Temperature, p.TopP, p.FrequencyPenalty, p.PresencePenalty)
209214
if err != nil {
210215
return nil, err
211216
}

web/src/ProviderEditPage.js

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
import React from "react";
16-
import {Button, Card, Col, Input, Row, Select} from "antd";
16+
import {Button, Card, Col, Input, InputNumber, Row, Select, Slider} from "antd";
1717
import * as ProviderBackend from "./backend/ProviderBackend";
1818
import * as Setting from "./Setting";
1919
import i18next from "i18next";
@@ -51,6 +51,8 @@ class ProviderEditPage extends React.Component {
5151
parseProviderField(key, value) {
5252
if ([""].includes(key)) {
5353
value = Setting.myParseInt(value);
54+
} else if (["temperature", "topP", "frequencyPenalty", "presencePenalty"].includes(key)) {
55+
value = Setting.myParseFloat(value);
5456
}
5557
return value;
5658
}
@@ -65,6 +67,44 @@ class ProviderEditPage extends React.Component {
6567
});
6668
}
6769

70+
InputSlider(props) {
71+
const {
72+
min,
73+
max,
74+
step,
75+
value,
76+
onChange,
77+
} = props;
78+
79+
return (
80+
<>
81+
<Col span={2}>
82+
<InputNumber
83+
min={min}
84+
max={max}
85+
step={step}
86+
style={{width: "100%"}}
87+
value={value}
88+
onChange={onChange}
89+
/>
90+
</Col>
91+
<Col span={20}>
92+
<Slider
93+
min={min}
94+
max={max}
95+
step={step}
96+
style={{
97+
marginLeft: "1%",
98+
marginRight: "1%",
99+
}}
100+
value={value}
101+
onChange={onChange}
102+
/>
103+
</Col>
104+
</>
105+
);
106+
}
107+
68108
renderProvider() {
69109
return (
70110
<Card size="small" title={
@@ -171,6 +211,74 @@ class ProviderEditPage extends React.Component {
171211
</Row>
172212
)
173213
}
214+
{
215+
(this.state.provider.category === "Model" && this.state.provider.type === "OpenAI") ? (
216+
<>
217+
<Row style={{marginTop: "20px"}}>
218+
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
219+
{i18next.t("provider:Temperature")}:
220+
</Col>
221+
<this.InputSlider
222+
min={0}
223+
max={2}
224+
step={0.01}
225+
value={this.state.provider.temperature}
226+
onChange={(value) => {
227+
this.updateProviderField("temperature", value);
228+
}}
229+
isMobile={Setting.isMobile()}
230+
/>
231+
</Row>
232+
<Row style={{marginTop: "20px"}}>
233+
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
234+
{i18next.t("provider:Top P")}:
235+
</Col>
236+
<this.InputSlider
237+
min={0}
238+
max={1}
239+
step={0.01}
240+
value={this.state.provider.topP}
241+
onChange={(value) => {
242+
this.updateProviderField("topP", value);
243+
}}
244+
isMobile={Setting.isMobile()}
245+
/>
246+
</Row>
247+
<Row style={{marginTop: "20px"}}>
248+
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
249+
{i18next.t("provider:Frequency penalty")}:
250+
</Col>
251+
<this.InputSlider
252+
label={i18next.t("provider:Frequency penalty")}
253+
min={-2}
254+
max={2}
255+
step={0.01}
256+
value={this.state.provider.frequencyPenalty}
257+
onChange={(value) => {
258+
this.updateProviderField("frequencyPenalty", value);
259+
}}
260+
isMobile={Setting.isMobile()}
261+
/>
262+
</Row>
263+
<Row style={{marginTop: "20px"}}>
264+
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
265+
{i18next.t("provider:Presence penalty")}:
266+
</Col>
267+
<this.InputSlider
268+
label={i18next.t("provider:Presence penalty")}
269+
min={-2}
270+
max={2}
271+
step={0.01}
272+
value={this.state.provider.presencePenalty}
273+
onChange={(value) => {
274+
this.updateProviderField("presencePenalty", value);
275+
}}
276+
isMobile={Setting.isMobile()}
277+
/>
278+
</Row>
279+
</>
280+
) : null
281+
}
174282
<Row style={{marginTop: "20px"}} >
175283
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
176284
{i18next.t("general:Provider URL")}:

web/src/ProviderListPage.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class ProviderListPage extends React.Component {
5858
subType: "text-davinci-003",
5959
clientId: "",
6060
clientSecret: "",
61+
temperature: 1,
62+
topP: 1,
63+
frequencyPenalty: 0,
64+
presencePenalty: 0,
6165
providerUrl: "https://platform.openai.com/account/api-keys",
6266
};
6367
}

web/src/Setting.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ export function myParseInt(i) {
7777
return isNaN(res) ? 0 : res;
7878
}
7979

80+
export function myParseFloat(f) {
81+
const res = parseFloat(f);
82+
return isNaN(res) ? 0.0 : res;
83+
}
84+
8085
export function openLink(link) {
8186
// this.props.history.push(link);
8287
const w = window.open("about:blank");

0 commit comments

Comments
 (0)