Skip to content

Commit 992ce50

Browse files
authored
feat: support the encodingFormat in embedding models (#1035)
**Description** This PR is support EncodingFormat in embedding models **Related Issues/PRs (if applicable)** Fixes #1034 Signed-off-by: misakazhou <[email protected]>
1 parent cdc8d37 commit 992ce50

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

internal/apischema/openai/openai.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,13 +1113,43 @@ type Embedding struct {
11131113
// Object: The object type, which is always "embedding".
11141114
Object string `json:"object"`
11151115

1116-
// Embedding: The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide.
1117-
Embedding []float64 `json:"embedding"`
1116+
// Embedding: The embedding vector, which can be a list of floats or a string.
1117+
// The length of vector depends on the model as listed in the embedding guide.
1118+
Embedding EmbeddingUnion `json:"embedding"`
11181119

11191120
// Index: The index of the embedding in the list of embeddings.
11201121
Index int `json:"index"`
11211122
}
11221123

1124+
// EmbeddingUnion is a union type that can handle both []float64 and string formats.
1125+
type EmbeddingUnion struct {
1126+
Value interface{}
1127+
}
1128+
1129+
// UnmarshalJSON implements json.Unmarshaler to handle both []float64 and string formats.
1130+
func (e *EmbeddingUnion) UnmarshalJSON(data []byte) error {
1131+
// Try to unmarshal as []float64 first.
1132+
var floats []float64
1133+
if err := json.Unmarshal(data, &floats); err == nil {
1134+
e.Value = floats
1135+
return nil
1136+
}
1137+
1138+
// Try to unmarshal as string.
1139+
var str string
1140+
if err := json.Unmarshal(data, &str); err == nil {
1141+
e.Value = str
1142+
return nil
1143+
}
1144+
1145+
return errors.New("embedding must be either []float64 or string")
1146+
}
1147+
1148+
// MarshalJSON implements json.Marshaler.
1149+
func (e EmbeddingUnion) MarshalJSON() ([]byte, error) {
1150+
return json.Marshal(e.Value)
1151+
}
1152+
11231153
// EmbeddingUsage represents the usage information for an embeddings request.
11241154
// https://platform.openai.com/docs/api-reference/embeddings/object#embeddings/object-usage
11251155
type EmbeddingUsage struct {

internal/apischema/openai/openai_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,3 +969,45 @@ func TestChatCompletionResponseChunkChoice(t *testing.T) {
969969
require.JSONEq(t, expected, string(jsonData))
970970
})
971971
}
972+
973+
func TestEmbeddingUnionUnmarshal(t *testing.T) {
974+
tests := []struct {
975+
name string
976+
input string
977+
want interface{}
978+
wantErr bool
979+
}{
980+
{
981+
name: "unmarshal array of floats",
982+
input: `[1.0, 2.0, 3.0]`,
983+
want: []float64{1.0, 2.0, 3.0},
984+
},
985+
{
986+
name: "unmarshal string",
987+
input: `"base64response"`,
988+
want: "base64response",
989+
},
990+
{
991+
name: "unmarshal int should error",
992+
input: `123`,
993+
wantErr: true,
994+
},
995+
}
996+
997+
for _, tt := range tests {
998+
t.Run(tt.name, func(t *testing.T) {
999+
var eu EmbeddingUnion
1000+
err := json.Unmarshal([]byte(tt.input), &eu)
1001+
if (err != nil) != tt.wantErr {
1002+
t.Errorf("EmbeddingUnion Unmarshal Error. error = %v, wantErr %v", err, tt.wantErr)
1003+
return
1004+
}
1005+
if !tt.wantErr {
1006+
// Use reflect.DeepEqual to compare
1007+
if !cmp.Equal(eu.Value, tt.want) {
1008+
t.Errorf("EmbeddingUnion Unmarshal Error. got = %v, want %v", eu.Value, tt.want)
1009+
}
1010+
}
1011+
})
1012+
}
1013+
}

0 commit comments

Comments
 (0)