Skip to content

Commit a8acb5f

Browse files
authored
Add tests (#171)
* test models listing * remove non-needed method * test for .streamFinished * add more error tests * improve stream testing * fix typo
1 parent fd44d36 commit a8acb5f

6 files changed

+23
-5
lines changed

api_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ func TestAPIError(t *testing.T) {
141141
if *apiErr.Code != "invalid_api_key" {
142142
t.Fatalf("Unexpected API error code: %s", *apiErr.Code)
143143
}
144+
if apiErr.Error() == "" {
145+
t.Fatal("Empty error message occured")
146+
}
144147
}
145148

146149
func TestRequestError(t *testing.T) {
@@ -163,6 +166,10 @@ func TestRequestError(t *testing.T) {
163166
if reqErr.StatusCode != 418 {
164167
t.Fatalf("Unexpected request error status code: %d", reqErr.StatusCode)
165168
}
169+
170+
if reqErr.Unwrap() == nil {
171+
t.Fatalf("Empty request error occured")
172+
}
166173
}
167174

168175
// numTokens Returns the number of GPT-3 encoded tokens in the given text.

chat_stream.go

-4
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ func (stream *ChatCompletionStream) Close() {
7878
stream.response.Body.Close()
7979
}
8080

81-
func (stream *ChatCompletionStream) GetResponse() *http.Response {
82-
return stream.response
83-
}
84-
8581
// CreateChatCompletionStream — API call to create a chat completion w/ streaming
8682
// support. It sets whether to stream back partial progress. If set, tokens will be
8783
// sent as data-only server-sent events as they become available, with the

chat_stream_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ func TestCreateChatCompletionStream(t *testing.T) {
116116
if !errors.Is(streamErr, io.EOF) {
117117
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
118118
}
119+
120+
_, streamErr = stream.Recv()
121+
if !errors.Is(streamErr, io.EOF) {
122+
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
123+
}
119124
}
120125

121126
// Helper funcs.

models.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ type ModelsList struct {
4040
// ListModels Lists the currently available models,
4141
// and provides basic information about each model such as the model id and parent.
4242
func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) {
43-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/models"), nil)
43+
req, err := c.requestBuilder.build(ctx, http.MethodGet, c.fullURL("/models"), nil)
4444
if err != nil {
4545
return
4646
}

request_builder_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
140140
if !errors.Is(err, errTestRequestBuilderFailed) {
141141
t.Fatalf("Did not return error when request builder failed: %v", err)
142142
}
143+
144+
_, err = client.ListModels(ctx)
145+
if !errors.Is(err, errTestRequestBuilderFailed) {
146+
t.Fatalf("Did not return error when request builder failed: %v", err)
147+
}
143148
}

stream_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ func TestCreateCompletionStream(t *testing.T) {
9393
if !errors.Is(streamErr, io.EOF) {
9494
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
9595
}
96+
97+
_, streamErr = stream.Recv()
98+
if !errors.Is(streamErr, io.EOF) {
99+
t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
100+
}
96101
}
97102

98103
// A "tokenRoundTripper" is a struct that implements the RoundTripper

0 commit comments

Comments
 (0)