Commit 9add1c3

add max_completions_tokens for o1 series models (#857)
* add max_completions_tokens for o1 series models
* add validation and beta limitations for o1 series models
Parent: 1ec8c24
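
With this change, callers targeting the o1 series set the new MaxCompletionsTokens field instead of the deprecated MaxTokens. A minimal caller-side sketch, assuming the usual go-openai import path; the API key, prompt, and error handling are placeholders, not part of this commit:

```go
package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-api-key") // placeholder key

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model:                openai.O1Preview,
		MaxCompletionsTokens: 1000, // replaces MaxTokens for o1 series models
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		// Requests that still set MaxTokens, or that hit an o1 beta
		// limitation (streaming, tools, system messages, sampling
		// overrides), now fail client-side before any HTTP request.
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Message.Content)
}
```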

File tree

5 files changed: +341 −12 lines

chat.go

+23 −12
@@ -200,18 +200,25 @@ type ChatCompletionResponseFormatJSONSchema struct {
 
 // ChatCompletionRequest represents a request structure for chat completion API.
 type ChatCompletionRequest struct {
-	Model            string                        `json:"model"`
-	Messages         []ChatCompletionMessage       `json:"messages"`
-	MaxTokens        int                           `json:"max_tokens,omitempty"`
-	Temperature      float32                       `json:"temperature,omitempty"`
-	TopP             float32                       `json:"top_p,omitempty"`
-	N                int                           `json:"n,omitempty"`
-	Stream           bool                          `json:"stream,omitempty"`
-	Stop             []string                      `json:"stop,omitempty"`
-	PresencePenalty  float32                       `json:"presence_penalty,omitempty"`
-	ResponseFormat   *ChatCompletionResponseFormat `json:"response_format,omitempty"`
-	Seed             *int                          `json:"seed,omitempty"`
-	FrequencyPenalty float32                       `json:"frequency_penalty,omitempty"`
+	Model    string                  `json:"model"`
+	Messages []ChatCompletionMessage `json:"messages"`
+	// MaxTokens The maximum number of tokens that can be generated in the chat completion.
+	// This value can be used to control costs for text generated via API.
+	// This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models.
+	// refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
+	MaxTokens int `json:"max_tokens,omitempty"`
+	// MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion,
+	// including visible output tokens and reasoning tokens. https://platform.openai.com/docs/guides/reasoning
+	MaxCompletionsTokens int                           `json:"max_completions_tokens,omitempty"`
+	Temperature          float32                       `json:"temperature,omitempty"`
+	TopP                 float32                       `json:"top_p,omitempty"`
+	N                    int                           `json:"n,omitempty"`
+	Stream               bool                          `json:"stream,omitempty"`
+	Stop                 []string                      `json:"stop,omitempty"`
+	PresencePenalty      float32                       `json:"presence_penalty,omitempty"`
+	ResponseFormat       *ChatCompletionResponseFormat `json:"response_format,omitempty"`
+	Seed                 *int                          `json:"seed,omitempty"`
+	FrequencyPenalty     float32                       `json:"frequency_penalty,omitempty"`
 	// LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.
 	// incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
 	// refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
@@ -364,6 +371,10 @@ func (c *Client) CreateChatCompletion(
 		return
 	}
 
+	if err = validateRequestForO1Models(request); err != nil {
+		return
+	}
+
 	req, err := c.newRequest(
 		ctx,
 		http.MethodPost,
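
The validateRequestForO1Models helper called above lives in the one changed file this view does not render (the commit touches 5 files but only 4 appear here). The following is a rough sketch only, reconstructed from the errors the new tests expect rather than copied from the commit; the o1SeriesModels set and the exact zero-value handling are assumptions:

```go
// Hypothetical sketch of the package-internal helper (package openai);
// not the committed implementation.
var o1SeriesModels = map[string]struct{}{
	O1Mini:    {},
	O1Preview: {},
}

func validateRequestForO1Models(request ChatCompletionRequest) error {
	if _, found := o1SeriesModels[request.Model]; !found {
		return nil // not an o1 series model; nothing to check
	}

	// max_tokens is deprecated for o1 models; use MaxCompletionsTokens.
	if request.MaxTokens > 0 {
		return ErrO1MaxTokensDeprecated
	}

	// Beta limitations: https://platform.openai.com/docs/guides/reasoning
	if request.Stream {
		return ErrO1BetaLimitationsStreaming
	}
	if request.LogProbs {
		return ErrO1BetaLimitationsLogprobs
	}
	for _, m := range request.Messages {
		if m.Role == ChatMessageRoleSystem {
			return ErrO1BetaLimitationsMessageTypes
		}
	}
	if len(request.Tools) > 0 {
		return ErrO1BetaLimitationsTools
	}
	// Sampling parameters are fixed during the beta; a zero value is
	// treated as unset (the fields are marshalled with omitempty).
	if request.Temperature > 0 && request.Temperature != 1 {
		return ErrO1BetaLimitationsOther
	}
	if request.TopP > 0 && request.TopP != 1 {
		return ErrO1BetaLimitationsOther
	}
	if request.N > 1 {
		return ErrO1BetaLimitationsOther
	}
	if request.PresencePenalty != 0 || request.FrequencyPenalty != 0 {
		return ErrO1BetaLimitationsOther
	}
	return nil
}
```

Note that in the streaming path (next file), the helper runs after request.Stream is set to true, so a single shared check like the Stream branch above is enough to surface ErrO1BetaLimitationsStreaming.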

chat_stream.go

+4
@@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream(
 	}
 
 	request.Stream = true
+	if err = validateRequestForO1Models(request); err != nil {
+		return
+	}
+
 	req, err := c.newRequest(
 		ctx,
 		http.MethodPost,

chat_stream_test.go

+21
@@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) {
 	}
 }
 
+func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) {
+	config := openai.DefaultConfig("whatever")
+	config.BaseURL = "http://localhost/v1/chat/completions"
+	client := openai.NewClientWithConfig(config)
+	ctx := context.Background()
+
+	req := openai.ChatCompletionRequest{
+		Model: openai.O1Preview,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	}
+	_, err := client.CreateChatCompletionStream(ctx, req)
+	if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) {
+		t.Fatalf("CreateChatCompletionStream should return ErrO1BetaLimitationsStreaming, but returned: %v", err)
+	}
+}
+
 func TestCreateChatCompletionStream(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
 	defer teardown()

chat_test.go

+211
@@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) {
 	checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg)
 }
 
+func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) {
+	tests := []struct {
+		name          string
+		in            openai.ChatCompletionRequest
+		expectedError error
+	}{
+		{
+			name: "o1-preview_MaxTokens_deprecated",
+			in: openai.ChatCompletionRequest{
+				MaxTokens: 5,
+				Model:     openai.O1Preview,
+			},
+			expectedError: openai.ErrO1MaxTokensDeprecated,
+		},
+		{
+			name: "o1-mini_MaxTokens_deprecated",
+			in: openai.ChatCompletionRequest{
+				MaxTokens: 5,
+				Model:     openai.O1Mini,
+			},
+			expectedError: openai.ErrO1MaxTokensDeprecated,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			config := openai.DefaultConfig("whatever")
+			config.BaseURL = "http://localhost/v1"
+			client := openai.NewClientWithConfig(config)
+			ctx := context.Background()
+
+			_, err := client.CreateChatCompletion(ctx, tt.in)
+			checks.HasError(t, err)
+			msg := fmt.Sprintf("CreateChatCompletion should return the expected error, returned: %s", err)
+			checks.ErrorIs(t, err, tt.expectedError, msg)
+		})
+	}
+}
+
+func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
+	tests := []struct {
+		name          string
+		in            openai.ChatCompletionRequest
+		expectedError error
+	}{
+		{
+			name: "log_probs_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				LogProbs:             true,
+				Model:                openai.O1Preview,
+			},
+			expectedError: openai.ErrO1BetaLimitationsLogprobs,
+		},
+		{
+			name: "message_type_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleSystem,
+					},
+				},
+			},
+			expectedError: openai.ErrO1BetaLimitationsMessageTypes,
+		},
+		{
+			name: "tool_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Tools: []openai.Tool{
+					{
+						Type: openai.ToolTypeFunction,
+					},
+				},
+			},
+			expectedError: openai.ErrO1BetaLimitationsTools,
+		},
+		{
+			name: "set_temperature_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(2),
+			},
+			expectedError: openai.ErrO1BetaLimitationsOther,
+		},
+		{
+			name: "set_top_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(0.1),
+			},
+			expectedError: openai.ErrO1BetaLimitationsOther,
+		},
+		{
+			name: "set_n_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				Temperature: float32(1),
+				TopP:        float32(1),
+				N:           2,
+			},
+			expectedError: openai.ErrO1BetaLimitationsOther,
+		},
+		{
+			name: "set_presence_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				PresencePenalty: float32(1),
+			},
+			expectedError: openai.ErrO1BetaLimitationsOther,
+		},
+		{
+			name: "set_frequency_penalty_unsupported",
+			in: openai.ChatCompletionRequest{
+				MaxCompletionsTokens: 1000,
+				Model:                openai.O1Mini,
+				Messages: []openai.ChatCompletionMessage{
+					{
+						Role: openai.ChatMessageRoleUser,
+					},
+					{
+						Role: openai.ChatMessageRoleAssistant,
+					},
+				},
+				FrequencyPenalty: float32(0.1),
+			},
+			expectedError: openai.ErrO1BetaLimitationsOther,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			config := openai.DefaultConfig("whatever")
+			config.BaseURL = "http://localhost/v1"
+			client := openai.NewClientWithConfig(config)
+			ctx := context.Background()
+
+			_, err := client.CreateChatCompletion(ctx, tt.in)
+			checks.HasError(t, err)
+			msg := fmt.Sprintf("CreateChatCompletion should return the expected error, returned: %s", err)
+			checks.ErrorIs(t, err, tt.expectedError, msg)
+		})
+	}
+}
+
 func TestChatRequestOmitEmpty(t *testing.T) {
 	data, err := json.Marshal(openai.ChatCompletionRequest{
 		// We set model b/c it's required, so omitempty doesn't make sense
@@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+// TestO1ModelChatCompletions Tests an o1 model against the chat completions endpoint using the mocked server.
+func TestO1ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:                openai.O1Preview,
+		MaxCompletionsTokens: 1000,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
