
Commit 9fa1df1

Authored by mltheuser (Malte Heuser) and co-author
KG-314 Add GeminiEmbedding to GoogleModels.kt (#1235)
Related to [KG-314](https://youtrack.jetbrains.com/issue/KG-314)

## Motivation and Context

This PR enables text embedding support for Google's Gemini models within the Koog framework, resolving issue #713. The `GoogleLLMClient` has been updated to implement the `LLMEmbeddingProvider` interface, allowing users to generate vector embeddings alongside the existing chat functionality. This includes the addition of the `gemini-embedding-001` model definition and the necessary API integration for the `embedContent` endpoint.

**Testing:** Functionality has been verified by adding the Gemini embedding model to the standard embedding integration test suite (`integration_testEmbed`).

## Breaking Changes

None.

---

#### Type of the changes

- [x] New feature (non-breaking change which adds functionality)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Tests improvement
- [ ] Refactoring

#### Checklist

- [x] The pull request has a description of the proposed change
- [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request
- [x] The pull request uses **`develop`** as the base branch
- [x] Tests for the changes have been added
- [x] All new and existing tests passed

##### Additional steps for pull requests adding a new feature

- [x] An issue describing the proposed change exists
- [x] The pull request includes a link to the issue
- [x] The change was discussed and approved in the issue
- [ ] Docs have been added / updated

---------

Co-authored-by: Malte Heuser <[email protected]>
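As a quick illustration of the feature described above, the sketch below embeds a string with the new `gemini-embedding-001` model. The `embed(text, model)` signature, `GoogleModels.Embeddings.GeminiEmbedding001`, and `close()` come from this commit; the API-key constructor argument and the `GOOGLE_API_KEY` environment variable are assumptions, since the client's full constructor is not part of this diff.

```kotlin
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
import ai.koog.prompt.executor.clients.google.GoogleModels
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    // Assumption: the API key is passed as the first constructor argument;
    // that parameter sits outside the hunks shown in this commit.
    val client = GoogleLLMClient(System.getenv("GOOGLE_API_KEY"))

    // New in this commit: suspend fun embed(text: String, model: LLModel): List<Double>
    val vector: List<Double> = client.embed(
        text = "Koog now supports Gemini text embeddings.",
        model = GoogleModels.Embeddings.GeminiEmbedding001,
    )

    println("Embedding dimension: ${vector.size}")
    client.close()
}
```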
1 parent 68227c5 · commit 9fa1df1

File tree: 8 files changed (+87 / -1 lines changed)


embeddings/embeddings-llm/build.gradle.kts

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ kotlin {
     dependencies {
         implementation(kotlin("test"))
         implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
+        implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-google-client"))
         implementation(libs.kotlinx.coroutines.core)
         implementation(libs.kotlinx.coroutines.test)
     }

embeddings/embeddings-llm/src/commonTest/kotlin/ai/koog/embeddings/local/LLMEmbedderTest.kt

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@ package ai.koog.embeddings.local
 
 import ai.koog.embeddings.base.Vector
 import ai.koog.prompt.executor.clients.LLMEmbeddingProvider
+import ai.koog.prompt.executor.clients.google.GoogleModels
 import ai.koog.prompt.executor.clients.openai.OpenAIModels
 import ai.koog.prompt.llm.LLModel
 import kotlinx.coroutines.test.runTest
@@ -14,6 +15,7 @@ class LLMEmbedderTest {
     val modelsList = listOf(
         OpenAIModels.Embeddings.TextEmbedding3Small,
         OllamaEmbeddingModels.NOMIC_EMBED_TEXT,
+        GoogleModels.Embeddings.GeminiEmbedding001,
     )
 
     @Test

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ object Models {
         BedrockModels.Embeddings.AmazonTitanEmbedText,
         OpenAIModels.Embeddings.TextEmbedding3Large,
         MistralAIModels.Embeddings.MistralEmbed,
+        GoogleModels.Embeddings.GeminiEmbedding001,
     )
 }

koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt

Lines changed: 1 addition & 0 deletions
@@ -260,6 +260,7 @@ private val GOOGLE_MODELS_MAP = mapOf(
     "gemini2_5pro" to GoogleModels.Gemini2_5Pro,
     "gemini2_5flash" to GoogleModels.Gemini2_5Flash,
     "gemini2_5flashlite" to GoogleModels.Gemini2_5FlashLite,
+    "gemini_embedding001" to GoogleModels.Embeddings.GeminiEmbedding001,
 )
 
 private val MISTRAL_MODELS_MAP = mapOf(

koog-ktor/src/commonTest/kotlin/ai/koog/ktor/ModelIdentifierParsingTest.kt

Lines changed: 5 additions & 0 deletions
@@ -249,6 +249,11 @@ class ModelIdentifierParsingTest {
         assertNotNull(gemini25FlashLite)
         assertEquals(LLMProvider.Google, gemini25FlashLite.provider)
         assertEquals(GoogleModels.Gemini2_5FlashLite, gemini25FlashLite)
+
+        val geminiEmbedding001 = getModelFromIdentifier("google.gemini_embedding001")
+        assertNotNull(geminiEmbedding001)
+        assertEquals(LLMProvider.Google, geminiEmbedding001.provider)
+        assertEquals(GoogleModels.Embeddings.GeminiEmbedding001, geminiEmbedding001)
     }
 
     // MistralAI model identifier tests

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt

Lines changed: 39 additions & 1 deletion
@@ -10,9 +10,12 @@ import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.clients.ConnectionTimeoutConfig
 import ai.koog.prompt.executor.clients.LLMClient
 import ai.koog.prompt.executor.clients.LLMClientException
+import ai.koog.prompt.executor.clients.LLMEmbeddingProvider
 import ai.koog.prompt.executor.clients.google.models.GoogleCandidate
 import ai.koog.prompt.executor.clients.google.models.GoogleContent
 import ai.koog.prompt.executor.clients.google.models.GoogleData
+import ai.koog.prompt.executor.clients.google.models.GoogleEmbeddingRequest
+import ai.koog.prompt.executor.clients.google.models.GoogleEmbeddingResponse
 import ai.koog.prompt.executor.clients.google.models.GoogleFunctionCallingConfig
 import ai.koog.prompt.executor.clients.google.models.GoogleFunctionCallingMode
 import ai.koog.prompt.executor.clients.google.models.GoogleFunctionDeclaration
@@ -79,6 +82,7 @@ public class GoogleClientSettings(
     public val defaultPath: String = "v1beta/models",
     public val generateContentMethod: String = "generateContent",
     public val streamGenerateContentMethod: String = "streamGenerateContent",
+    public val embedContentMethod: String = "embedContent"
 )
 
 /**
@@ -97,7 +101,7 @@ public open class GoogleLLMClient(
     private val settings: GoogleClientSettings = GoogleClientSettings(),
     baseClient: HttpClient = HttpClient(),
     private val clock: Clock = Clock.System
-) : LLMClient {
+) : LLMClient, LLMEmbeddingProvider {
 
     @OptIn(InternalStructuredOutputApi::class)
     private companion object {
@@ -757,4 +761,38 @@ public open class GoogleLLMClient(
     override fun close() {
         httpClient.close()
     }
+
+    override suspend fun embed(text: String, model: LLModel): List<Double> {
+        require(model.capabilities.contains(LLMCapability.Embed)) {
+            "Model ${model.id} does not support embedding."
+        }
+
+        logger.debug { "Embedding text with model: ${model.id}" }
+
+        val request = GoogleEmbeddingRequest(
+            model = "models/${model.id}",
+            content = GoogleContent(
+                parts = listOf(GooglePart.Text(text))
+            )
+        )
+
+        try {
+            val response = httpClient.post(
+                path = "${settings.defaultPath}/${model.id}:${settings.embedContentMethod}",
+                request = request,
+                requestBodyType = GoogleEmbeddingRequest::class,
+                responseType = GoogleEmbeddingResponse::class,
+            )
+
+            return response.embedding.values
+        } catch (e: CancellationException) {
+            throw e
+        } catch (e: Exception) {
+            throw LLMClientException(
+                clientName = clientName,
+                message = e.message,
+                cause = e
+            )
+        }
+    }
 }
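The `embed` implementation above guards on `LLMCapability.Embed` before issuing any HTTP request. Below is a hedged sketch of the expected behaviour, assuming chat models such as `Gemini2_5Flash` do not declare the `Embed` capability (their capability lists are not shown in this diff).

```kotlin
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
import ai.koog.prompt.executor.clients.google.GoogleModels

suspend fun capabilityGuardExample(client: GoogleLLMClient) {
    // GeminiEmbedding001 declares LLMCapability.Embed, so the require() check passes
    // and the request is sent to the embedContent endpoint.
    val vector = client.embed("hello embeddings", GoogleModels.Embeddings.GeminiEmbedding001)
    println("Got ${vector.size} dimensions")

    // Assumption: a chat-only model lacks the Embed capability, so require() throws
    // IllegalArgumentException before any network call is made.
    runCatching { client.embed("hello embeddings", GoogleModels.Gemini2_5Flash) }
        .onFailure { error -> println("Rejected: ${error.message}") }
}
```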

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleModels.kt

Lines changed: 19 additions & 0 deletions
@@ -150,4 +150,23 @@ public object GoogleModels : LLModelDefinitions {
         contextLength = 1_048_576,
         maxOutputTokens = 65_536,
     )
+
+    /**
+     * Models for generating text embeddings.
+     */
+    public object Embeddings {
+        /**
+         * Gemini embedding model for generating embeddings for words, phrases, and sentences.
+         *
+         * Input token limit: 2048
+         *
+         * @see <a href="https://ai.google.dev/gemini-api/docs/embeddings#model-versions">
+         */
+        public val GeminiEmbedding001: LLModel = LLModel(
+            provider = LLMProvider.Google,
+            id = "gemini-embedding-001",
+            capabilities = listOf(LLMCapability.Embed),
+            contextLength = 2048,
+        )
+    }
 }
New file · Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+package ai.koog.prompt.executor.clients.google.models
+
+import kotlinx.serialization.Serializable
+
+@Serializable
+internal data class GoogleEmbeddingRequest(
+    val model: String,
+    val content: GoogleContent
+)
+
+@Serializable
+internal data class GoogleEmbeddingResponse(
+    val embedding: GoogleEmbeddingData
+)
+
+@Serializable
+internal data class GoogleEmbeddingData(
+    val values: List<Double>
+)
