-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Expand file tree
/
Copy pathgemini.go
More file actions
133 lines (108 loc) · 3.6 KB
/
gemini.go
File metadata and controls
133 lines (108 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gemini
import (
"context"
"fmt"
"net/http"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util"
"google.golang.org/genai"
)
const EmbeddingModelType string = "gemini"
// validate interface
var _ embeddingmodels.EmbeddingModelConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Dimension int32 `yaml:"dimension"`
}
// Returns the embedding model type
func (cfg Config) EmbeddingModelConfigType() string {
return EmbeddingModelType
}
// Initialize a Gemini embedding model
func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) {
// Get client configs
configs := &genai.ClientConfig{}
if cfg.ApiKey != "" {
configs.APIKey = cfg.ApiKey
}
ua, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user agent from context: %w", err)
}
configs.HTTPOptions = genai.HTTPOptions{
Headers: http.Header{
"User-Agent": []string{ua},
},
}
// Create new Gemini API client
client, err := genai.NewClient(ctx, configs)
if err != nil {
return nil, fmt.Errorf("unable to create Gemini API client: %w", err)
}
m := &EmbeddingModel{
Config: cfg,
Client: client,
}
return m, nil
}
var _ embeddingmodels.EmbeddingModel = EmbeddingModel{}
type EmbeddingModel struct {
Client *genai.Client
Config
}
// Returns the embedding model type
func (m EmbeddingModel) EmbeddingModelType() string {
return EmbeddingModelType
}
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
return m.Config
}
func (m EmbeddingModel) EmbedParameters(ctx context.Context, parameters []string) ([][]float32, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
contents := convertStringsToContents(parameters)
embedConfig := &genai.EmbedContentConfig{
TaskType: "SEMANTIC_SIMILARITY",
}
if m.Dimension > 0 {
embedConfig.OutputDimensionality = genai.Ptr(m.Dimension)
}
result, err := m.Client.Models.EmbedContent(ctx, m.Model, contents, embedConfig)
if err != nil {
logger.ErrorContext(ctx, "Error calling EmbedContent for model %s: %v", m.Model, err)
return nil, err
}
embeddings := make([][]float32, 0, len(result.Embeddings))
for _, embedding := range result.Embeddings {
embeddings = append(embeddings, embedding.Values)
}
logger.InfoContext(ctx, "Successfully embedded %d text parameters using model %s", len(parameters), m.Model)
return embeddings, nil
}
// convertStringsToContents takes a slice of strings and converts it into a slice of *genai.Content objects.
func convertStringsToContents(texts []string) []*genai.Content {
contents := make([]*genai.Content, 0, len(texts))
for _, text := range texts {
content := genai.NewContentFromText(text, "")
contents = append(contents, content)
}
return contents
}