Skip to content

Commit e01a2d7

Browse files
authored
convert EmbeddingModel to string type (#629)
This gives the user the ability to pass in models for embeddings that are not already defined in the library. It also more closely matches how the completions API works.
1 parent 682b7ad commit e01a2d7

File tree

2 files changed

+24
-118
lines changed

2 files changed

+24
-118
lines changed

embeddings.go

Lines changed: 21 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch")
1313

1414
// EmbeddingModel enumerates the models which can be used
1515
// to generate Embedding vectors.
16-
type EmbeddingModel int
17-
18-
// String implements the fmt.Stringer interface.
19-
func (e EmbeddingModel) String() string {
20-
return enumToString[e]
21-
}
22-
23-
// MarshalText implements the encoding.TextMarshaler interface.
24-
func (e EmbeddingModel) MarshalText() ([]byte, error) {
25-
return []byte(e.String()), nil
26-
}
27-
28-
// UnmarshalText implements the encoding.TextUnmarshaler interface.
29-
// On unrecognized value, it sets |e| to Unknown.
30-
func (e *EmbeddingModel) UnmarshalText(b []byte) error {
31-
if val, ok := stringToEnum[(string(b))]; ok {
32-
*e = val
33-
return nil
34-
}
35-
36-
*e = Unknown
37-
38-
return nil
39-
}
16+
type EmbeddingModel string
4017

4118
const (
42-
Unknown EmbeddingModel = iota
43-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
44-
AdaSimilarity
45-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
46-
BabbageSimilarity
47-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
48-
CurieSimilarity
49-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
50-
DavinciSimilarity
51-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
52-
AdaSearchDocument
53-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
54-
AdaSearchQuery
55-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
56-
BabbageSearchDocument
57-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
58-
BabbageSearchQuery
59-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
60-
CurieSearchDocument
61-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
62-
CurieSearchQuery
63-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
64-
DavinciSearchDocument
65-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
66-
DavinciSearchQuery
67-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
68-
AdaCodeSearchCode
69-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
70-
AdaCodeSearchText
71-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
72-
BabbageCodeSearchCode
73-
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
74-
BabbageCodeSearchText
75-
AdaEmbeddingV2
19+
// Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
20+
AdaSimilarity EmbeddingModel = "text-similarity-ada-001"
21+
BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001"
22+
CurieSimilarity EmbeddingModel = "text-similarity-curie-001"
23+
DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001"
24+
AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001"
25+
AdaSearchQuery EmbeddingModel = "text-search-ada-query-001"
26+
BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001"
27+
BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001"
28+
CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001"
29+
CurieSearchQuery EmbeddingModel = "text-search-curie-query-001"
30+
DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001"
31+
DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001"
32+
AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001"
33+
AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001"
34+
BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001"
35+
BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001"
36+
37+
AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
7638
)
7739

78-
var enumToString = map[EmbeddingModel]string{
79-
AdaSimilarity: "text-similarity-ada-001",
80-
BabbageSimilarity: "text-similarity-babbage-001",
81-
CurieSimilarity: "text-similarity-curie-001",
82-
DavinciSimilarity: "text-similarity-davinci-001",
83-
AdaSearchDocument: "text-search-ada-doc-001",
84-
AdaSearchQuery: "text-search-ada-query-001",
85-
BabbageSearchDocument: "text-search-babbage-doc-001",
86-
BabbageSearchQuery: "text-search-babbage-query-001",
87-
CurieSearchDocument: "text-search-curie-doc-001",
88-
CurieSearchQuery: "text-search-curie-query-001",
89-
DavinciSearchDocument: "text-search-davinci-doc-001",
90-
DavinciSearchQuery: "text-search-davinci-query-001",
91-
AdaCodeSearchCode: "code-search-ada-code-001",
92-
AdaCodeSearchText: "code-search-ada-text-001",
93-
BabbageCodeSearchCode: "code-search-babbage-code-001",
94-
BabbageCodeSearchText: "code-search-babbage-text-001",
95-
AdaEmbeddingV2: "text-embedding-ada-002",
96-
}
97-
98-
var stringToEnum = map[string]EmbeddingModel{
99-
"text-similarity-ada-001": AdaSimilarity,
100-
"text-similarity-babbage-001": BabbageSimilarity,
101-
"text-similarity-curie-001": CurieSimilarity,
102-
"text-similarity-davinci-001": DavinciSimilarity,
103-
"text-search-ada-doc-001": AdaSearchDocument,
104-
"text-search-ada-query-001": AdaSearchQuery,
105-
"text-search-babbage-doc-001": BabbageSearchDocument,
106-
"text-search-babbage-query-001": BabbageSearchQuery,
107-
"text-search-curie-doc-001": CurieSearchDocument,
108-
"text-search-curie-query-001": CurieSearchQuery,
109-
"text-search-davinci-doc-001": DavinciSearchDocument,
110-
"text-search-davinci-query-001": DavinciSearchQuery,
111-
"code-search-ada-code-001": AdaCodeSearchCode,
112-
"code-search-ada-text-001": AdaCodeSearchText,
113-
"code-search-babbage-code-001": BabbageCodeSearchCode,
114-
"code-search-babbage-text-001": BabbageCodeSearchText,
115-
"text-embedding-ada-002": AdaEmbeddingV2,
116-
}
117-
11840
// Embedding is a special format of data representation that can be easily utilized by machine
11941
// learning models and algorithms. The embedding is an information dense representation of the
12042
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
@@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings(
306228
conv EmbeddingRequestConverter,
307229
) (res EmbeddingResponse, err error) {
308230
baseReq := conv.Convert()
309-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
231+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq))
310232
if err != nil {
311233
return
312234
}

embeddings_test.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
4747
// the AdaSearchQuery type
4848
marshaled, err := json.Marshal(embeddingReq)
4949
checks.NoError(t, err, "Could not marshal embedding request")
50-
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
50+
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
5151
t.Fatalf("Expected embedding request to contain model field")
5252
}
5353

@@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
6161
}
6262
marshaled, err = json.Marshal(embeddingReqStrings)
6363
checks.NoError(t, err, "Could not marshal embedding request")
64-
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
64+
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
6565
t.Fatalf("Expected embedding request to contain model field")
6666
}
6767

@@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
7575
}
7676
marshaled, err = json.Marshal(embeddingReqTokens)
7777
checks.NoError(t, err, "Could not marshal embedding request")
78-
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
78+
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
7979
t.Fatalf("Expected embedding request to contain model field")
8080
}
8181
}
8282
}
8383

84-
func TestEmbeddingModel(t *testing.T) {
85-
var em openai.EmbeddingModel
86-
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
87-
checks.NoError(t, err, "Could not marshal embedding model")
88-
89-
if em != openai.AdaSimilarity {
90-
t.Errorf("Model is not equal to AdaSimilarity")
91-
}
92-
93-
err = em.UnmarshalText([]byte("some-non-existent-model"))
94-
checks.NoError(t, err, "Could not marshal embedding model")
95-
if em != openai.Unknown {
96-
t.Errorf("Model is not equal to Unknown")
97-
}
98-
}
99-
10084
func TestEmbeddingEndpoint(t *testing.T) {
10185
client, server, teardown := setupOpenAITestServer()
10286
defer teardown()

0 commit comments

Comments
 (0)