Skip to content

Commit a6b35c3

Browse files
authored
Check for Stream parameter usage (#174)

* check for stream:true usage
* lint
1 parent 34f3a11 commit a6b35c3

6 files changed

+45
-7
lines changed

chat.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ const (
1414
)
1515

1616
var (
17-
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported")
17+
ErrChatCompletionInvalidModel = errors.New("currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported") //nolint:lll
18+
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
1819
)
1920

2021
type ChatCompletionMessage struct {
@@ -65,8 +66,12 @@ func (c *Client) CreateChatCompletion(
6566
ctx context.Context,
6667
request ChatCompletionRequest,
6768
) (response ChatCompletionResponse, err error) {
68-
model := request.Model
69-
switch model {
69+
if request.Stream {
70+
err = ErrChatCompletionStreamNotSupported
71+
return
72+
}
73+
74+
switch request.Model {
7075
case GPT3Dot5Turbo0301, GPT3Dot5Turbo, GPT4, GPT40314, GPT432K0314, GPT432K:
7176
default:
7277
err = ErrChatCompletionInvalidModel

chat_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,21 @@ func TestChatCompletionsWrongModel(t *testing.T) {
3838
}
3939
}
4040

41+
func TestChatCompletionsWithStream(t *testing.T) {
42+
config := DefaultConfig("whatever")
43+
config.BaseURL = "http://localhost/v1"
44+
client := NewClientWithConfig(config)
45+
ctx := context.Background()
46+
47+
req := ChatCompletionRequest{
48+
Stream: true,
49+
}
50+
_, err := client.CreateChatCompletion(ctx, req)
51+
if !errors.Is(err, ErrChatCompletionStreamNotSupported) {
52+
t.Fatalf("CreateChatCompletion didn't return ErrChatCompletionStreamNotSupported error")
53+
}
54+
}
55+
4156
// TestCompletions Tests the completions endpoint of the API using the mocked server.
4257
func TestChatCompletions(t *testing.T) {
4358
server := test.NewTestServer()

completion.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import (
77
)
88

99
var (
10-
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
10+
ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll
11+
ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll
1112
)
1213

1314
// GPT3 Defines the models provided by OpenAI to use when generating
@@ -99,6 +100,11 @@ func (c *Client) CreateCompletion(
99100
ctx context.Context,
100101
request CompletionRequest,
101102
) (response CompletionResponse, err error) {
103+
if request.Stream {
104+
err = ErrCompletionStreamNotSupported
105+
return
106+
}
107+
102108
if request.Model == GPT3Dot5Turbo0301 || request.Model == GPT3Dot5Turbo {
103109
err = ErrCompletionUnsupportedModel
104110
return

completion_test.go

+12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ func TestCompletionsWrongModel(t *testing.T) {
3333
}
3434
}
3535

36+
func TestCompletionWithStream(t *testing.T) {
37+
config := DefaultConfig("whatever")
38+
client := NewClientWithConfig(config)
39+
40+
ctx := context.Background()
41+
req := CompletionRequest{Stream: true}
42+
_, err := client.CreateCompletion(ctx, req)
43+
if !errors.Is(err, ErrCompletionStreamNotSupported) {
44+
t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported")
45+
}
46+
}
47+
3648
// TestCompletions Tests the completions endpoint of the API using the mocked server.
3749
func TestCompletions(t *testing.T) {
3850
server := test.NewTestServer()

models_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestListModels(t *testing.T) {
3333
}
3434

3535
// handleModelsEndpoint Handles the models endpoint by the test server.
36-
func handleModelsEndpoint(w http.ResponseWriter, r *http.Request) {
36+
func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
3737
resBytes, _ := json.Marshal(ModelsList{})
3838
fmt.Fprintln(w, string(resBytes))
3939
}

request_builder_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ type (
1919
failingMarshaller struct{}
2020
)
2121

22-
func (*failingMarshaller) marshal(value any) ([]byte, error) {
22+
func (*failingMarshaller) marshal(_ any) ([]byte, error) {
2323
return []byte{}, errTestMarshallerFailed
2424
}
2525

26-
func (*failingRequestBuilder) build(ctx context.Context, method, url string, requset any) (*http.Request, error) {
26+
func (*failingRequestBuilder) build(_ context.Context, _, _ string, _ any) (*http.Request, error) {
2727
return nil, errTestRequestBuilderFailed
2828
}
2929

0 commit comments

Comments (0)