Skip to content

Commit be253c2

Browse files
authored
change azure engine config to modelMapper (#306)
* change azure engine config to azure modelMapper config * Update go.mod * Revert "Update go.mod" This reverts commit 78d14c5. * lint fix * add test * lint fix * lint fix * lint fix * opt * opt * opt * opt
1 parent 5f4ff3e commit be253c2

14 files changed

+119
-32
lines changed

api_internal_test.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) {
9494
az.OrgID = c.OrgID
9595

9696
cli := NewClientWithConfig(az)
97-
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
97+
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "")
9898
if err != nil {
9999
t.Errorf("Failed to create request: %v", err)
100100
}
@@ -109,14 +109,16 @@ func TestRequestAuthHeader(t *testing.T) {
109109

110110
func TestAzureFullURL(t *testing.T) {
111111
cases := []struct {
112-
Name string
113-
BaseURL string
114-
Engine string
115-
Expect string
112+
Name string
113+
BaseURL string
114+
AzureModelMapper map[string]string
115+
Model string
116+
Expect string
116117
}{
117118
{
118119
"AzureBaseURLWithSlashAutoStrip",
119120
"https://httpbin.org/",
121+
nil,
120122
"chatgpt-demo",
121123
"https://httpbin.org/" +
122124
"openai/deployments/chatgpt-demo" +
@@ -125,6 +127,7 @@ func TestAzureFullURL(t *testing.T) {
125127
{
126128
"AzureBaseURLWithoutSlashOK",
127129
"https://httpbin.org",
130+
nil,
128131
"chatgpt-demo",
129132
"https://httpbin.org/" +
130133
"openai/deployments/chatgpt-demo" +
@@ -134,10 +137,10 @@ func TestAzureFullURL(t *testing.T) {
134137

135138
for _, c := range cases {
136139
t.Run(c.Name, func(t *testing.T) {
137-
az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
140+
az := DefaultAzureConfig("dummy", c.BaseURL)
138141
cli := NewClientWithConfig(az)
139142
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
140-
actual := cli.fullURL("/chat/completions")
143+
actual := cli.fullURL("/chat/completions", c.Model)
141144
if actual != c.Expect {
142145
t.Errorf("Expected %s, got %s", c.Expect, actual)
143146
}

audio.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (c *Client) callAudioAPI(
6868
}
6969

7070
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
71-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody)
71+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody)
7272
if err != nil {
7373
return AudioResponse{}, err
7474
}

chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (c *Client) CreateChatCompletion(
7777
return
7878
}
7979

80-
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
80+
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
8181
if err != nil {
8282
return
8383
}

chat_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func (c *Client) CreateChatCompletionStream(
4646
}
4747

4848
request.Stream = true
49-
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
49+
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
5050
if err != nil {
5151
return
5252
}

client.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,10 @@ func decodeString(body io.Reader, output *string) error {
9898
return nil
9999
}
100100

101-
func (c *Client) fullURL(suffix string) string {
102-
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
101+
// fullURL returns full URL for request.
102+
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
103+
func (c *Client) fullURL(suffix string, args ...any) string {
104+
// /openai/deployments/{model}/chat/completions?api-version={api_version}
103105
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
104106
baseURL := c.config.BaseURL
105107
baseURL = strings.TrimRight(baseURL, "/")
@@ -108,8 +110,17 @@ func (c *Client) fullURL(suffix string) string {
108110
if strings.Contains(suffix, "/models") {
109111
return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion)
110112
}
113+
azureDeploymentName := "UNKNOWN"
114+
if len(args) > 0 {
115+
model, ok := args[0].(string)
116+
if ok {
117+
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
118+
}
119+
}
111120
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
112-
baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
121+
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
122+
azureDeploymentName, suffix, c.config.APIVersion,
123+
)
113124
}
114125

115126
// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
@@ -120,8 +131,9 @@ func (c *Client) newStreamRequest(
120131
ctx context.Context,
121132
method string,
122133
urlSuffix string,
123-
body any) (*http.Request, error) {
124-
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix), body)
134+
body any,
135+
model string) (*http.Request, error) {
136+
req, err := c.requestBuilder.build(ctx, method, c.fullURL(urlSuffix, model), body)
125137
if err != nil {
126138
return nil, err
127139
}

completion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func (c *Client) CreateCompletion(
155155
return
156156
}
157157

158-
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix), request)
158+
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request)
159159
if err != nil {
160160
return
161161
}

config.go

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai
22

33
import (
44
"net/http"
5+
"regexp"
56
)
67

78
const (
@@ -26,13 +27,12 @@ const AzureAPIKeyHeader = "api-key"
2627
type ClientConfig struct {
2728
authToken string
2829

29-
BaseURL string
30-
OrgID string
31-
APIType APIType
32-
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
33-
Engine string // required when APIType is APITypeAzure or APITypeAzureAD
34-
35-
HTTPClient *http.Client
30+
BaseURL string
31+
OrgID string
32+
APIType APIType
33+
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
34+
AzureModelMapperFunc func(model string) string // replace model to azure deployment name func
35+
HTTPClient *http.Client
3636

3737
EmptyMessagesLimit uint
3838
}
@@ -50,14 +50,16 @@ func DefaultConfig(authToken string) ClientConfig {
5050
}
5151
}
5252

53-
func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
53+
func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
5454
return ClientConfig{
5555
authToken: apiKey,
5656
BaseURL: baseURL,
5757
OrgID: "",
5858
APIType: APITypeAzure,
5959
APIVersion: "2023-03-15-preview",
60-
Engine: engine,
60+
AzureModelMapperFunc: func(model string) string {
61+
return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
62+
},
6163

6264
HTTPClient: &http.Client{},
6365

@@ -68,3 +70,11 @@ func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
6870
func (ClientConfig) String() string {
6971
return "<OpenAI API ClientConfig>"
7072
}
73+
74+
func (c ClientConfig) GetAzureDeploymentByModel(model string) string {
75+
if c.AzureModelMapperFunc != nil {
76+
return c.AzureModelMapperFunc(model)
77+
}
78+
79+
return model
80+
}

config_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package openai_test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/sashabaranov/go-openai"
7+
)
8+
9+
func TestGetAzureDeploymentByModel(t *testing.T) {
10+
cases := []struct {
11+
Model string
12+
AzureModelMapperFunc func(model string) string
13+
Expect string
14+
}{
15+
{
16+
Model: "gpt-3.5-turbo",
17+
Expect: "gpt-35-turbo",
18+
},
19+
{
20+
Model: "gpt-3.5-turbo-0301",
21+
Expect: "gpt-35-turbo-0301",
22+
},
23+
{
24+
Model: "text-embedding-ada-002",
25+
Expect: "text-embedding-ada-002",
26+
},
27+
{
28+
Model: "",
29+
Expect: "",
30+
},
31+
{
32+
Model: "models",
33+
Expect: "models",
34+
},
35+
{
36+
Model: "gpt-3.5-turbo",
37+
Expect: "my-gpt35",
38+
AzureModelMapperFunc: func(model string) string {
39+
modelmapper := map[string]string{
40+
"gpt-3.5-turbo": "my-gpt35",
41+
}
42+
if val, ok := modelmapper[model]; ok {
43+
return val
44+
}
45+
return model
46+
},
47+
},
48+
}
49+
50+
for _, c := range cases {
51+
t.Run(c.Model, func(t *testing.T) {
52+
conf := DefaultAzureConfig("", "https://test.openai.azure.com/")
53+
if c.AzureModelMapperFunc != nil {
54+
conf.AzureModelMapperFunc = c.AzureModelMapperFunc
55+
}
56+
actual := conf.GetAzureDeploymentByModel(c.Model)
57+
if actual != c.Expect {
58+
t.Errorf("Expected %s, got %s", c.Expect, actual)
59+
}
60+
})
61+
}
62+
}

edits.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
)
78

@@ -31,7 +32,7 @@ type EditsResponse struct {
3132

3233
// Perform an API call to the Edits endpoint.
3334
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
34-
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits"), request)
35+
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request)
3536
if err != nil {
3637
return
3738
}

embeddings.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ type EmbeddingRequest struct {
132132
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
133133
// https://beta.openai.com/docs/api-reference/embeddings/create
134134
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
135-
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings"), request)
135+
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request)
136136
if err != nil {
137137
return
138138
}

example_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,7 @@ func Example_chatbot() {
305305
func ExampleDefaultAzureConfig() {
306306
azureKey := os.Getenv("AZURE_OPENAI_API_KEY") // Your azure API key
307307
azureEndpoint := os.Getenv("AZURE_OPENAI_ENDPOINT") // Your azure OpenAI endpoint
308-
azureModel := os.Getenv("AZURE_OPENAI_MODEL") // Your model deployment name
309-
config := openai.DefaultAzureConfig(azureKey, azureEndpoint, azureModel)
308+
config := openai.DefaultAzureConfig(azureKey, azureEndpoint)
310309
client := openai.NewClientWithConfig(config)
311310
resp, err := client.CreateChatCompletion(
312311
context.Background(),

models_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestAzureListModels(t *testing.T) {
4040
ts.Start()
4141
defer ts.Close()
4242

43-
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/", "dummyengine")
43+
config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/")
4444
config.BaseURL = ts.URL
4545
client := NewClientWithConfig(config)
4646
ctx := context.Background()

moderation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ type ModerationResponse struct {
6363
// Moderations — perform a moderation api call over a string.
6464
// Input can be an array or slice but a string will reduce the complexity.
6565
func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) {
66-
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations"), request)
66+
req, err := c.requestBuilder.build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request)
6767
if err != nil {
6868
return
6969
}

stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (c *Client) CreateCompletionStream(
3535
}
3636

3737
request.Stream = true
38-
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request)
38+
req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model)
3939
if err != nil {
4040
return
4141
}

0 commit comments

Comments
 (0)