Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ kind: tool
name: insert_embedding
type: postgres-sql
source: my-pg-instance
description: Insert a new document into the database.
statement: |
INSERT INTO documents (content, embedding)
VALUES ($1, $2);
Expand All @@ -92,6 +93,7 @@ kind: tool
name: search_embedding
type: postgres-sql
source: my-pg-instance
description: Search for documents in the database.
statement: |
SELECT id, content, embedding <-> $1 AS distance
FROM documents
Expand Down
63 changes: 43 additions & 20 deletions docs/en/documentation/configuration/embedding-models/gemini.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ title: "Gemini Embedding"
type: docs
weight: 1
description: >
Use Google's Gemini models to generate high-performance text embeddings for vector databases.
Use Google's Gemini models to generate high-performance text embeddings for
vector databases.
---

## About
Expand All @@ -13,15 +14,17 @@ high-dimensional vectors.

### Authentication

Toolbox uses your [Application Default Credentials
(ADC)][adc] to authorize with the
Gemini API client.
Toolbox supports two authentication modes:

Optionally, you can use an [API key][api-key]; obtain an API
key from [Google AI Studio][ai-studio].
1. **Google AI (API Key):** Used if you
provide `apiKey` (or set `GOOGLE_API_KEY`/`GEMINI_API_KEY` environment
variables). This uses the [Google AI Studio][ai-studio] backend.
2. **Vertex AI (ADC):** Used if you provide `project` and `location` (or set the
`GOOGLE_CLOUD_PROJECT`/`GOOGLE_CLOUD_LOCATION` environment variables). This uses [Application
Default Credentials (ADC)][adc]. If both modes are configured, Vertex AI takes
precedence.

We recommend using an API key for testing and using application default
credentials for production.
We recommend using an API key for quick testing and using Vertex AI with ADC for
production environments.

[adc]: https://cloud.google.com/docs/authentication#adc
[api-key]: https://ai.google.dev/gemini-api/docs/api-key#api-keys
Expand All @@ -41,14 +44,19 @@ to your database source.
The `dimension` field must match the expected size of your database column
(e.g., a `vector(768)` column in PostgreSQL). This setting is supported only by
models released since 2024. You cannot set this value if using the earlier model
(`models/embedding-001`). Check out [available Gemini models][modellist] for more
information.
(`models/embedding-001`). Check out [available Gemini models][modellist] for
more information.

[modellist]:
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models

## Example

### Using Google AI

Google AI uses an API key for authentication. You can get an API key from [Google
AI Studio][ai-studio].

```yaml
kind: embeddingModel
name: gemini-model
Expand All @@ -58,16 +66,31 @@ apiKey: ${GOOGLE_API_KEY}
dimension: 768
```

{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your secrets into the configuration file.
### Using Vertex AI

Vertex AI uses Application Default Credentials (ADC) for authentication. Learn
how to [set up ADC][adc].

```yaml
kind: embeddingModel
name: gemini-model
type: gemini
model: gemini-embedding-001
project: ${GOOGLE_CLOUD_PROJECT}
location: us-central1
dimension: 768
```

[adc]: https://docs.cloud.google.com/docs/authentication/provide-credentials-adc

{{< notice tip >}} Use environment variable replacement with the format
${ENV_NAME} instead of hardcoding your secrets into the configuration file.
{{< /notice >}}

## Reference

| **field** | **type** | **required** | **description** |
|-----------|:--------:|:------------:|--------------------------------------------------------------|
| type | string | true | Must be `gemini`. |
| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). |
| apiKey | string | false | Your API Key from Google AI Studio. |
| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). |
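| **field** | **type** | **required** | **description** |
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
| type | string | true | Must be `gemini`. |
| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). |
| apiKey | string | false | Your API key from Google AI Studio. If omitted, the `GOOGLE_API_KEY` or `GEMINI_API_KEY` environment variable is used. |
| project | string | false | The Google Cloud project ID for Vertex AI. If omitted, the `GOOGLE_CLOUD_PROJECT` environment variable is used. |
| location | string | false | The Google Cloud location for Vertex AI (e.g., `us-central1`). If omitted, the `GOOGLE_CLOUD_LOCATION` environment variable is used. |
| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). |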
64 changes: 57 additions & 7 deletions internal/embeddingmodels/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"net/http"
"os"

"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util"
Expand All @@ -34,6 +35,8 @@ type Config struct {
Type string `yaml:"type" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Project string `yaml:"project"`
Location string `yaml:"location"`
Dimension int32 `yaml:"dimension"`
}

Expand All @@ -44,12 +47,60 @@ func (cfg Config) EmbeddingModelConfigType() string {

// Initialize a Gemini embedding model
func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) {
// Get client configs
configs := &genai.ClientConfig{}
if cfg.ApiKey != "" {
configs.APIKey = cfg.ApiKey

// Retrieve logger from context
l, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to retrieve logger: %w", err)
}

// Get API Key
apiKey := cfg.ApiKey
if apiKey == "" {
apiKey = os.Getenv("GOOGLE_API_KEY")
}
if apiKey == "" {
apiKey = os.Getenv("GEMINI_API_KEY")
}

// Try to resolve Project and Location
project := cfg.Project
if project == "" {
project = os.Getenv("GOOGLE_CLOUD_PROJECT")
}

location := cfg.Location
if location == "" {
location = os.Getenv("GOOGLE_CLOUD_LOCATION")
}

// Determine the Backend
if project != "" && location != "" {
// VertexAI API uses ADC for authentication.
// ADC requires `Project` and `Location` to be set.
configs.Backend = genai.BackendVertexAI
configs.Project = project
configs.Location = location

l.InfoContext(ctx, "Using Vertex AI backend for Gemini embedding", "project", project, "location", location)

} else if apiKey != "" {
// Using Gemini API, which uses API Key for authentication.
configs.Backend = genai.BackendGeminiAPI
configs.APIKey = apiKey

l.InfoContext(ctx, "Using Google AI (Gemini API) backend for Gemini embedding")

} else {
// Missing credentials
return nil, fmt.Errorf("missing credentials for Gemini embedding: " +
"For Google AI: Provide 'apiKey' in YAML or set GOOGLE_API_KEY/GEMINI_API_KEY env vars. " +
"For Vertex AI: Provide 'project'/'location' in YAML or via GOOGLE_CLOUD_PROJECT/GOOGLE_CLOUD_LOCATION env vars. " +
"See documentation for details: https://googleapis.github.io/genai-toolbox/resources/embeddingmodels/gemini/")
}

// Set user agent
ua, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get user agent from context: %w", err)
Expand All @@ -63,14 +114,13 @@ func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingMode
// Create new Gemini API client
client, err := genai.NewClient(ctx, configs)
if err != nil {
return nil, fmt.Errorf("unable to create Gemini API client")
return nil, fmt.Errorf("unable to create Gemini API client: %w", err)
}

m := &EmbeddingModel{
return &EmbeddingModel{
Config: cfg,
Client: client,
}
return m, nil
}, nil
}

var _ embeddingmodels.EmbeddingModel = EmbeddingModel{}
Expand Down
75 changes: 61 additions & 14 deletions internal/embeddingmodels/gemini/gemini_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package gemini_test

import (
"context"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -37,36 +38,58 @@ func TestParseFromYamlGemini(t *testing.T) {
kind: embeddingModel
name: my-gemini-model
type: gemini
model: text-embedding-004
model: gemini-embedding-001
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"my-gemini-model": gemini.Config{
Name: "my-gemini-model",
Type: gemini.EmbeddingModelType,
Model: "text-embedding-004",
Model: "gemini-embedding-001",
},
},
},
{
desc: "full example with optional fields",
desc: "full example with Google AI fields",
in: `
kind: embeddingModel
name: complex-gemini
type: gemini
model: text-embedding-004
model: gemini-embedding-001
apiKey: "test-api-key"
dimension: 768
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"complex-gemini": gemini.Config{
Name: "complex-gemini",
Type: gemini.EmbeddingModelType,
Model: "text-embedding-004",
Model: "gemini-embedding-001",
ApiKey: "test-api-key",
Dimension: 768,
},
},
},
{
desc: "Vertex AI configuration",
in: `
kind: embeddingModel
name: vertex-gemini
type: gemini
model: gemini-embedding-001
project: "my-project"
location: "us-central1"
dimension: 512
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"vertex-gemini": gemini.Config{
Name: "vertex-gemini",
Type: gemini.EmbeddingModelType,
Model: "gemini-embedding-001",
Project: "my-project",
Location: "us-central1",
Dimension: 512,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
Expand All @@ -81,6 +104,7 @@ func TestParseFromYamlGemini(t *testing.T) {
})
}
}

func TestFailParseFromYamlGemini(t *testing.T) {
tcs := []struct {
desc string
Expand All @@ -94,7 +118,6 @@ func TestFailParseFromYamlGemini(t *testing.T) {
name: bad-model
type: gemini
`,
// Removed the specific model name from the prefix to match your output
err: "error unmarshaling embeddingModel: unable to parse as \"bad-model\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
},
{
Expand All @@ -103,21 +126,45 @@ func TestFailParseFromYamlGemini(t *testing.T) {
kind: embeddingModel
name: bad-field
type: gemini
model: text-embedding-004
model: gemini-embedding-001
invalid_param: true
`,
// Updated to match the specific line-starting format of your error output
err: "error unmarshaling embeddingModel: unable to parse as \"bad-field\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | model: text-embedding-004\n 3 | name: bad-field\n 4 | type: gemini",
err: "error unmarshaling embeddingModel: unable to parse as \"bad-field\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | model: gemini-embedding-001\n 3 | name: bad-field\n 4 | type: gemini",
},
{
desc: "missing both Vertex and Google AI credentials",
in: `
kind: embeddingModel
name: missing-creds
type: gemini
model: text-embedding-004
`,
err: "unable to initialize embedding model \"missing-creds\": missing credentials for Gemini embedding: For Google AI: Provide 'apiKey' in YAML or set GOOGLE_API_KEY/GEMINI_API_KEY env vars. For Vertex AI: Provide 'project'/'location' in YAML or via GOOGLE_CLOUD_PROJECT/GOOGLE_CLOUD_LOCATION env vars. See documentation for details: https://googleapis.github.io/genai-toolbox/resources/embeddingmodels/gemini/",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err == nil {
t.Fatalf("expect parsing to fail")
t.Setenv("GOOGLE_API_KEY", "")
t.Setenv("GEMINI_API_KEY", "")
t.Setenv("GOOGLE_CLOUD_PROJECT", "")
t.Setenv("GOOGLE_CLOUD_LOCATION", "")

_, embeddingConfigs, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
if err != nil {
if err.Error() != tc.err {
t.Fatalf("unexpected unmarshal error:\ngot: %q\nwant: %q", err.Error(), tc.err)
}
return
}
if err.Error() != tc.err {
t.Fatalf("unexpected error:\ngot: %q\nwant: %q", err.Error(), tc.err)

for _, cfg := range embeddingConfigs {
_, err = cfg.Initialize()
if err == nil {
t.Fatalf("expect initialization to fail for case: %s", tc.desc)
}
if !strings.Contains(err.Error(), tc.err) {
t.Fatalf("unexpected init error:\ngot: %q\nwant: %q", err.Error(), tc.err)
}
}
})
}
Expand Down
Loading