Skip to content

Commit 6830e00

Browse files
authored
Support Retrieve model API (#340) (#341)
* Support Retrieve model API (#340) * Test for GetModel error cases. (#340) * Reduce the cognitive complexity of TestClientReturnsRequestBuilderErrors (#340)
1 parent 1394329 commit 6830e00

File tree

3 files changed

+127
-94
lines changed

3 files changed

+127
-94
lines changed

client_test.go

Lines changed: 72 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -170,104 +170,82 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
170170

171171
ctx := context.Background()
172172

173-
_, err = client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
174-
if !errors.Is(err, errTestRequestBuilderFailed) {
175-
t.Fatalf("Did not return error when request builder failed: %v", err)
176-
}
177-
178-
_, err = client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
179-
if !errors.Is(err, errTestRequestBuilderFailed) {
180-
t.Fatalf("Did not return error when request builder failed: %v", err)
181-
}
182-
183-
_, err = client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
184-
if !errors.Is(err, errTestRequestBuilderFailed) {
185-
t.Fatalf("Did not return error when request builder failed: %v", err)
186-
}
187-
188-
_, err = client.CreateFineTune(ctx, FineTuneRequest{})
189-
if !errors.Is(err, errTestRequestBuilderFailed) {
190-
t.Fatalf("Did not return error when request builder failed: %v", err)
191-
}
192-
193-
_, err = client.ListFineTunes(ctx)
194-
if !errors.Is(err, errTestRequestBuilderFailed) {
195-
t.Fatalf("Did not return error when request builder failed: %v", err)
196-
}
197-
198-
_, err = client.CancelFineTune(ctx, "")
199-
if !errors.Is(err, errTestRequestBuilderFailed) {
200-
t.Fatalf("Did not return error when request builder failed: %v", err)
201-
}
202-
203-
_, err = client.GetFineTune(ctx, "")
204-
if !errors.Is(err, errTestRequestBuilderFailed) {
205-
t.Fatalf("Did not return error when request builder failed: %v", err)
206-
}
207-
208-
_, err = client.DeleteFineTune(ctx, "")
209-
if !errors.Is(err, errTestRequestBuilderFailed) {
210-
t.Fatalf("Did not return error when request builder failed: %v", err)
211-
}
212-
213-
_, err = client.ListFineTuneEvents(ctx, "")
214-
if !errors.Is(err, errTestRequestBuilderFailed) {
215-
t.Fatalf("Did not return error when request builder failed: %v", err)
216-
}
217-
218-
_, err = client.Moderations(ctx, ModerationRequest{})
219-
if !errors.Is(err, errTestRequestBuilderFailed) {
220-
t.Fatalf("Did not return error when request builder failed: %v", err)
221-
}
222-
223-
_, err = client.Edits(ctx, EditsRequest{})
224-
if !errors.Is(err, errTestRequestBuilderFailed) {
225-
t.Fatalf("Did not return error when request builder failed: %v", err)
226-
}
227-
228-
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
229-
if !errors.Is(err, errTestRequestBuilderFailed) {
230-
t.Fatalf("Did not return error when request builder failed: %v", err)
173+
type TestCase struct {
174+
Name string
175+
TestFunc func() (any, error)
231176
}
232177

233-
_, err = client.CreateImage(ctx, ImageRequest{})
234-
if !errors.Is(err, errTestRequestBuilderFailed) {
235-
t.Fatalf("Did not return error when request builder failed: %v", err)
236-
}
237-
238-
err = client.DeleteFile(ctx, "")
239-
if !errors.Is(err, errTestRequestBuilderFailed) {
240-
t.Fatalf("Did not return error when request builder failed: %v", err)
241-
}
242-
243-
_, err = client.GetFile(ctx, "")
244-
if !errors.Is(err, errTestRequestBuilderFailed) {
245-
t.Fatalf("Did not return error when request builder failed: %v", err)
178+
testCases := []TestCase{
179+
{"CreateCompletion", func() (any, error) {
180+
return client.CreateCompletion(ctx, CompletionRequest{Prompt: "testing"})
181+
}},
182+
{"CreateCompletionStream", func() (any, error) {
183+
return client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
184+
}},
185+
{"CreateChatCompletion", func() (any, error) {
186+
return client.CreateChatCompletion(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
187+
}},
188+
{"CreateChatCompletionStream", func() (any, error) {
189+
return client.CreateChatCompletionStream(ctx, ChatCompletionRequest{Model: GPT3Dot5Turbo})
190+
}},
191+
{"CreateFineTune", func() (any, error) {
192+
return client.CreateFineTune(ctx, FineTuneRequest{})
193+
}},
194+
{"ListFineTunes", func() (any, error) {
195+
return client.ListFineTunes(ctx)
196+
}},
197+
{"CancelFineTune", func() (any, error) {
198+
return client.CancelFineTune(ctx, "")
199+
}},
200+
{"GetFineTune", func() (any, error) {
201+
return client.GetFineTune(ctx, "")
202+
}},
203+
{"DeleteFineTune", func() (any, error) {
204+
return client.DeleteFineTune(ctx, "")
205+
}},
206+
{"ListFineTuneEvents", func() (any, error) {
207+
return client.ListFineTuneEvents(ctx, "")
208+
}},
209+
{"Moderations", func() (any, error) {
210+
return client.Moderations(ctx, ModerationRequest{})
211+
}},
212+
{"Edits", func() (any, error) {
213+
return client.Edits(ctx, EditsRequest{})
214+
}},
215+
{"CreateEmbeddings", func() (any, error) {
216+
return client.CreateEmbeddings(ctx, EmbeddingRequest{})
217+
}},
218+
{"CreateImage", func() (any, error) {
219+
return client.CreateImage(ctx, ImageRequest{})
220+
}},
221+
{"DeleteFile", func() (any, error) {
222+
return nil, client.DeleteFile(ctx, "")
223+
}},
224+
{"GetFile", func() (any, error) {
225+
return client.GetFile(ctx, "")
226+
}},
227+
{"ListFiles", func() (any, error) {
228+
return client.ListFiles(ctx)
229+
}},
230+
{"ListEngines", func() (any, error) {
231+
return client.ListEngines(ctx)
232+
}},
233+
{"GetEngine", func() (any, error) {
234+
return client.GetEngine(ctx, "")
235+
}},
236+
{"ListModels", func() (any, error) {
237+
return client.ListModels(ctx)
238+
}},
239+
{"GetModel", func() (any, error) {
240+
return client.GetModel(ctx, "text-davinci-003")
241+
}},
246242
}
247243

248-
_, err = client.ListFiles(ctx)
249-
if !errors.Is(err, errTestRequestBuilderFailed) {
250-
t.Fatalf("Did not return error when request builder failed: %v", err)
251-
}
252-
253-
_, err = client.ListEngines(ctx)
254-
if !errors.Is(err, errTestRequestBuilderFailed) {
255-
t.Fatalf("Did not return error when request builder failed: %v", err)
256-
}
257-
258-
_, err = client.GetEngine(ctx, "")
259-
if !errors.Is(err, errTestRequestBuilderFailed) {
260-
t.Fatalf("Did not return error when request builder failed: %v", err)
261-
}
262-
263-
_, err = client.ListModels(ctx)
264-
if !errors.Is(err, errTestRequestBuilderFailed) {
265-
t.Fatalf("Did not return error when request builder failed: %v", err)
266-
}
267-
268-
_, err = client.CreateCompletionStream(ctx, CompletionRequest{Prompt: ""})
269-
if !errors.Is(err, errTestRequestBuilderFailed) {
270-
t.Fatalf("Did not return error when request builder failed: %v", err)
244+
for _, testCase := range testCases {
245+
_, err = testCase.TestFunc()
246+
if !errors.Is(err, errTestRequestBuilderFailed) {
247+
t.Fatalf("%s did not return error when request builder failed: %v", testCase.Name, err)
248+
}
271249
}
272250
}
273251

models.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
)
78

@@ -48,3 +49,16 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error)
4849
err = c.sendRequest(req, &models)
4950
return
5051
}
52+
53+
// GetModel Retrieves a model instance, providing basic information about
54+
// the model such as the owner and permissioning.
55+
func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) {
56+
urlSuffix := fmt.Sprintf("/models/%s", modelID)
57+
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
58+
if err != nil {
59+
return
60+
}
61+
62+
err = c.sendRequest(req, &model)
63+
return
64+
}

models_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,44 @@ func handleModelsEndpoint(w http.ResponseWriter, _ *http.Request) {
5454
resBytes, _ := json.Marshal(ModelsList{})
5555
fmt.Fprintln(w, string(resBytes))
5656
}
57+
58+
// TestGetModel Tests the retrieve model endpoint of the API using the mocked server.
59+
func TestGetModel(t *testing.T) {
60+
server := test.NewTestServer()
61+
server.RegisterHandler("/v1/models/text-davinci-003", handleGetModelEndpoint)
62+
// create the test server
63+
ts := server.OpenAITestServer()
64+
ts.Start()
65+
defer ts.Close()
66+
67+
config := DefaultConfig(test.GetTestToken())
68+
config.BaseURL = ts.URL + "/v1"
69+
client := NewClientWithConfig(config)
70+
ctx := context.Background()
71+
72+
_, err := client.GetModel(ctx, "text-davinci-003")
73+
checks.NoError(t, err, "GetModel error")
74+
}
75+
76+
func TestAzureGetModel(t *testing.T) {
77+
server := test.NewTestServer()
78+
server.RegisterHandler("/openai/models/text-davinci-003", handleModelsEndpoint)
79+
// create the test server
80+
ts := server.OpenAITestServer()
81+
ts.Start()
82+
defer ts.Close()
83+
84+
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
85+
config.BaseURL = ts.URL
86+
client := NewClientWithConfig(config)
87+
ctx := context.Background()
88+
89+
_, err := client.GetModel(ctx, "text-davinci-003")
90+
checks.NoError(t, err, "GetModel error")
91+
}
92+
93+
// handleModelsEndpoint Handles the models endpoint by the test server.
94+
func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) {
95+
resBytes, _ := json.Marshal(Model{})
96+
fmt.Fprintln(w, string(resBytes))
97+
}

0 commit comments

Comments
 (0)