Skip to content

Commit 1a82f3c

Browse files
authored
fix(google): classify embedding models by supportedGenerationMethods (#14)
* fix(google): classify embedding models by supportedGenerationMethods ListModels was hardcoding all Google models as sdk.ModelTypeChat, discarding the supportedGenerationMethods field already parsed in googleModelObject. Models with only "embedContent" (no "generateContent") are now correctly classified as sdk.ModelTypeEmbedding. Also adds ModelTypeEmbedding constant to the sdk package. Ref: memohai/Memoh#533 * docs: add ModelTypeEmbedding to API reference and skill docs
1 parent a9dadaa commit 1a82f3c

5 files changed

Lines changed: 67 additions & 4 deletions

File tree

docs/api-reference.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ type Model struct {
111111
}
112112

113113
type ModelType string
114-
const ModelTypeChat ModelType = "chat"
114+
const ModelTypeChat ModelType = "chat"
115+
const ModelTypeEmbedding ModelType = "embedding"
115116
```
116117

117118
#### Methods

provider/google/generativeai/generativeai.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func (p *Provider) ListModels(ctx context.Context) ([]sdk.Model, error) {
7474
ID: id,
7575
DisplayName: m.DisplayName,
7676
Provider: p,
77-
Type: sdk.ModelTypeChat,
77+
Type: googleModelType(m.SupportedGenerationMethods),
7878
})
7979
}
8080
return models, nil
@@ -134,6 +134,23 @@ func (p *Provider) ChatModel(id string) *sdk.Model {
134134
}
135135
}
136136

137+
func googleModelType(methods []string) sdk.ModelType {
138+
hasGenerate := false
139+
hasEmbed := false
140+
for _, m := range methods {
141+
switch m {
142+
case "generateContent":
143+
hasGenerate = true
144+
case "embedContent":
145+
hasEmbed = true
146+
}
147+
}
148+
if hasEmbed && !hasGenerate {
149+
return sdk.ModelTypeEmbedding
150+
}
151+
return sdk.ModelTypeChat
152+
}
153+
137154
// ---------- DoGenerate ----------
138155

139156
func (p *Provider) DoGenerate(ctx context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) { //nolint:gocritic // interface method

provider/google/generativeai/generativeai_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,49 @@ func TestListModels(t *testing.T) {
11061106
if models[0].DisplayName != "Gemini 2.5 Pro" {
11071107
t.Errorf("expected display name 'Gemini 2.5 Pro', got %q", models[0].DisplayName)
11081108
}
1109+
if models[0].Type != sdk.ModelTypeChat {
1110+
t.Errorf("expected chat type for model without supportedGenerationMethods, got %q", models[0].Type)
1111+
}
1112+
}
1113+
1114+
func TestListModels_EmbeddingType(t *testing.T) {
1115+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1116+
w.Header().Set("Content-Type", "application/json")
1117+
json.NewEncoder(w).Encode(map[string]any{
1118+
"models": []map[string]any{
1119+
{
1120+
"name": "models/gemini-2.5-flash",
1121+
"displayName": "Gemini 2.5 Flash",
1122+
"supportedGenerationMethods": []string{"generateContent", "countTokens"},
1123+
},
1124+
{
1125+
"name": "models/gemini-embedding-001",
1126+
"displayName": "Gemini Embedding 001",
1127+
"supportedGenerationMethods": []string{"embedContent", "countTokens"},
1128+
},
1129+
},
1130+
})
1131+
}))
1132+
defer srv.Close()
1133+
1134+
p := generativeai.New(
1135+
generativeai.WithAPIKey("test-key"),
1136+
generativeai.WithBaseURL(srv.URL),
1137+
)
1138+
1139+
models, err := p.ListModels(context.Background())
1140+
if err != nil {
1141+
t.Fatalf("ListModels failed: %v", err)
1142+
}
1143+
if len(models) != 2 {
1144+
t.Fatalf("expected 2 models, got %d", len(models))
1145+
}
1146+
if models[0].Type != sdk.ModelTypeChat {
1147+
t.Errorf("expected chat type for generateContent model, got %q", models[0].Type)
1148+
}
1149+
if models[1].Type != sdk.ModelTypeEmbedding {
1150+
t.Errorf("expected embedding type for embedContent model, got %q", models[1].Type)
1151+
}
11091152
}
11101153

11111154
func TestProviderTest_OK(t *testing.T) {

sdk/model.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import "context"
55
type ModelType string
66

77
const (
8-
ModelTypeChat ModelType = "chat"
8+
ModelTypeChat ModelType = "chat"
9+
ModelTypeEmbedding ModelType = "embedding"
910
)
1011

1112
type Model struct {

skill/reference.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ type ModelTestResult struct {
7979
```go
8080
type ModelType string
8181

82-
const ModelTypeChat ModelType = "chat"
82+
const ModelTypeChat ModelType = "chat"
83+
const ModelTypeEmbedding ModelType = "embedding"
8384

8485
type Model struct {
8586
ID string

0 commit comments

Comments
 (0)