Skip to content

Commit 1572e75

Browse files
authored
fix: add preset max tokens for OpenAI model provider (#637)
1 parent 13d1659 commit 1572e75

1 file changed

Lines changed: 35 additions & 2 deletions

File tree

model/openai.go

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@ import (
2525
"github.com/sashabaranov/go-openai"
2626
)
2727

28+
// https://pkg.go.dev/github.com/sashabaranov/go-openai@v1.12.0#pkg-constants
29+
// https://platform.openai.com/docs/models/overview
30+
var __maxTokens = map[string]int{
31+
openai.GPT4: 8192,
32+
openai.GPT40613: 8192,
33+
openai.GPT432K: 32768,
34+
openai.GPT432K0613: 32768,
35+
openai.GPT40314: 8192,
36+
openai.GPT432K0314: 32768,
37+
openai.GPT3Dot5Turbo: 4097,
38+
openai.GPT3Dot5Turbo16K: 16385,
39+
openai.GPT3Dot5Turbo0613: 4097,
40+
openai.GPT3Dot5Turbo16K0613: 16385,
41+
openai.GPT3Dot5Turbo0301: 4097,
42+
openai.GPT3TextDavinci003: 4097,
43+
openai.GPT3TextDavinci002: 4097,
44+
openai.GPT3TextCurie001: 2049,
45+
openai.GPT3TextBabbage001: 2049,
46+
openai.GPT3TextAda001: 2049,
47+
openai.GPT3Davinci: 2049,
48+
openai.GPT3Curie: 2049,
49+
openai.GPT3Ada: 2049,
50+
openai.GPT3Babbage: 2049,
51+
}
52+
2853
type OpenAiModelProvider struct {
2954
subType string
3055
secretKey string
@@ -42,6 +67,15 @@ func getProxyClientFromToken(authToken string) *openai.Client {
4267
return c
4368
}
4469

70+
// GetMaxTokens returns the max tokens for a given openai model.
71+
func (p *OpenAiModelProvider) GetMaxTokens() int {
72+
res, ok := __maxTokens[p.subType]
73+
if !ok {
74+
return 4097
75+
}
76+
return res
77+
}
78+
4579
func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
4680
client := getProxyClientFromToken(p.secretKey)
4781

@@ -63,8 +97,7 @@ func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, build
6397
return err
6498
}
6599

66-
// https://platform.openai.com/docs/models/gpt-3-5
67-
maxTokens := 4097 - promptTokens
100+
maxTokens := p.GetMaxTokens() - promptTokens
68101

69102
respStream, err := client.CreateCompletionStream(
70103
ctx,

0 commit comments

Comments
 (0)