@@ -158,6 +158,32 @@ func TestEmbeddingEndpoint(t *testing.T) {
158
158
checks .HasError (t , err , "CreateEmbeddings error" )
159
159
}
160
160
161
+ func TestAzureEmbeddingEndpoint (t * testing.T ) {
162
+ client , server , teardown := setupAzureTestServer ()
163
+ defer teardown ()
164
+
165
+ sampleEmbeddings := []openai.Embedding {
166
+ {Embedding : []float32 {1.23 , 4.56 , 7.89 }},
167
+ {Embedding : []float32 {- 0.006968617 , - 0.0052718227 , 0.011901081 }},
168
+ }
169
+
170
+ server .RegisterHandler (
171
+ "/openai/deployments/text-embedding-ada-002/embeddings" ,
172
+ func (w http.ResponseWriter , r * http.Request ) {
173
+ resBytes , _ := json .Marshal (openai.EmbeddingResponse {Data : sampleEmbeddings })
174
+ fmt .Fprintln (w , string (resBytes ))
175
+ },
176
+ )
177
+ // test create embeddings with strings (simple embedding request)
178
+ res , err := client .CreateEmbeddings (context .Background (), openai.EmbeddingRequest {
179
+ Model : openai .AdaEmbeddingV2 ,
180
+ })
181
+ checks .NoError (t , err , "CreateEmbeddings error" )
182
+ if ! reflect .DeepEqual (res .Data , sampleEmbeddings ) {
183
+ t .Errorf ("Expected %#v embeddings, got %#v" , sampleEmbeddings , res .Data )
184
+ }
185
+ }
186
+
161
187
func TestEmbeddingResponseBase64_ToEmbeddingResponse (t * testing.T ) {
162
188
type fields struct {
163
189
Object string
0 commit comments