Skip to content

Commit 0196370

Browse files
committed
allow one to specify width of vector in echo embedding
this will be useful for testing as not all embedding models emit a vector of the same size
1 parent ce503ed commit 0196370

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

Diff for: pkg/aicli/echoai.go

+24-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package aicli
22

33
import (
44
"io"
5+
"strconv"
6+
"strings"
57

68
"github.com/pkg/errors"
79
)
@@ -31,10 +33,30 @@ func (c *Echo) GetEmbedding(req *EmbeddingRequest) ([]Embedding, error) {
3133
if len(req.Inputs) == 0 {
3234
return nil, errors.New("emtpy list of inputs")
3335
}
36+
// parse req.Model to see what size vector to return. Format is name_num.
37+
split := strings.Split(req.Model, "_")
38+
var vectorSize int
39+
switch len(split) {
40+
case 0:
41+
return nil, errors.New("empty model name")
42+
case 1:
43+
vectorSize = 1
44+
default:
45+
num, err := strconv.Atoi(split[1])
46+
if err != nil {
47+
return nil, errors.Wrap(err, "parsing model vector size")
48+
}
49+
vectorSize = num
50+
}
51+
if vectorSize <= 0 {
52+
return nil, errors.Errorf("vector size must be greater than 0, got %d", vectorSize)
53+
}
3454
ret := make([]Embedding, len(req.Inputs))
3555
for i := range ret {
36-
ret[i].Embedding = []float32{0.42}
56+
ret[i].Embedding = make([]float32, vectorSize)
57+
for j := range ret[i].Embedding {
58+
ret[i].Embedding[j] = 0.42
59+
}
3760
}
3861
return ret, nil
39-
4062
}

Diff for: pkg/aicli/echoai_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,41 @@ func TestEcho(t *testing.T) {
4444
})
4545
}
4646
}
47+
48+
func TestEchoEmbeddings(t *testing.T) {
49+
c := &aicli.Echo{}
50+
cases := []struct {
51+
model string
52+
exp []float32
53+
expErr string
54+
}{
55+
{
56+
model: "basic",
57+
exp: []float32{0.42},
58+
},
59+
{
60+
model: "numbered_2",
61+
exp: []float32{0.42, 0.42},
62+
},
63+
{
64+
model: "numbered_-11",
65+
expErr: "size must be greater than 0",
66+
},
67+
}
68+
69+
for i, tst := range cases {
70+
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
71+
er := &aicli.EmbeddingRequest{
72+
Model: tst.model,
73+
Inputs: []string{"hello"},
74+
}
75+
emb, err := c.GetEmbedding(er)
76+
if tst.expErr != "" {
77+
require.Contains(t, err.Error(), tst.expErr)
78+
} else {
79+
require.NoError(t, err)
80+
require.Equal(t, tst.exp, emb[0].Embedding)
81+
}
82+
})
83+
}
84+
}

0 commit comments

Comments
 (0)