Skip to content

Commit d528b50

Browse files
committed
feat(embedding): Implement Google embedding dimension control and tests
Google Provider: - Add GoogleEmbeddingParams with taskType and title support - Add toGoogleEmbeddingParams() extension for polymorphic conversion - Update GoogleEmbeddingRequest DTO with outputDimensionality, taskType, title - Update GoogleLLMClient.embed() to use EmbeddingParams with capability validation - Add Embedding.Dimensions capability to GeminiEmbedding001 model Unit Tests: - Add GoogleEmbeddingParamsTest for validation and conversion - Update LLMEmbedderTest with embedBatch tests (provider-agnostic) Integration Tests: - Add integration_testEmbedWithDimensions (tests 256-dim output) - Add integration_testEmbedBatch (tests 3-text batch) - Add dimensionCapableEmbeddingModels() stream - Temporarily limit embeddingModels() to Google until other providers migrate Part of KG-104 (dimension control) and KG-538 (batch embedding)
1 parent bfd3a67 commit d528b50

File tree

11 files changed

+457
-13
lines changed

11 files changed

+457
-13
lines changed

embeddings/embeddings-llm/build.gradle.kts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ kotlin {
2525
commonTest {
2626
dependencies {
2727
implementation(kotlin("test"))
28-
implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
28+
// TODO: Re-enable after OpenAI migration
29+
// implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
2930
implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-google-client"))
3031
implementation(libs.kotlinx.coroutines.core)
3132
implementation(libs.kotlinx.coroutines.test)

embeddings/embeddings-llm/src/commonTest/kotlin/ai/koog/embeddings/local/LLMEmbedderTest.kt

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@ package ai.koog.embeddings.local
33
import ai.koog.embeddings.base.Vector
44
import ai.koog.prompt.executor.clients.LLMEmbeddingProvider
55
import ai.koog.prompt.executor.clients.google.GoogleModels
6-
import ai.koog.prompt.executor.clients.openai.OpenAIModels
6+
// TODO: Uncomment after OpenAI migration
7+
// import ai.koog.prompt.executor.clients.openai.OpenAIModels
78
import ai.koog.prompt.llm.LLModel
9+
import ai.koog.prompt.params.EmbeddingParams
810
import kotlinx.coroutines.test.runTest
911
import kotlin.test.Test
1012
import kotlin.test.assertEquals
13+
import kotlin.test.assertNull
1114

1215
class LLMEmbedderTest {
1316
// Using a pretty straightforward approach as commonTest doesn't support @ParametrizedTest annotation from JUnit5
1417
// Discussable, though.
18+
// TODO: Re-enable after all providers migrated
1519
val modelsList = listOf(
16-
OpenAIModels.Embeddings.TextEmbedding3Small,
17-
OllamaEmbeddingModels.NOMIC_EMBED_TEXT,
20+
// OpenAIModels.Embeddings.TextEmbedding3Small,
21+
// OllamaEmbeddingModels.NOMIC_EMBED_TEXT,
1822
GoogleModels.Embeddings.GeminiEmbedding001,
1923
)
2024

@@ -75,15 +79,89 @@ class LLMEmbedderTest {
7579
}
7680
}
7781

82+
@Test
83+
fun testEmbedBatch_usesDefaultParallelImplementation() = runTest {
84+
val model = modelsList.first()
85+
val mockClient = MockEmbedderClient()
86+
87+
val texts = listOf("text1", "text2", "text3")
88+
val vectors = listOf(
89+
Vector(listOf(0.1, 0.2)),
90+
Vector(listOf(0.3, 0.4)),
91+
Vector(listOf(0.5, 0.6))
92+
)
93+
94+
// Mock individual embeddings
95+
texts.forEachIndexed { i, text -> mockClient.mockEmbedding(text, vectors[i]) }
96+
97+
val results = mockClient.embedBatch(texts, model, EmbeddingParams())
98+
99+
assertEquals(3, results.size)
100+
assertEquals(listOf(0.1, 0.2), results[0])
101+
assertEquals(listOf(0.3, 0.4), results[1])
102+
assertEquals(listOf(0.5, 0.6), results[2])
103+
}
104+
105+
@Test
106+
fun testEmbedBatch_passesParamsToUnderlyingEmbed() = runTest {
107+
val model = modelsList.first()
108+
val mockClient = MockEmbedderClient()
109+
110+
mockClient.mockEmbedding("test", Vector(listOf(1.0)))
111+
112+
val params = EmbeddingParams(dimensions = 256)
113+
mockClient.embedBatch(listOf("test"), model, params)
114+
115+
assertEquals(256, mockClient.lastParams?.dimensions)
116+
}
117+
118+
@Test
119+
fun testEmbed_defaultParamsHasNullDimensions() = runTest {
120+
val model = modelsList.first()
121+
val mockClient = MockEmbedderClient()
122+
123+
mockClient.mockEmbedding("test", Vector(listOf(1.0)))
124+
mockClient.embed("test", model, EmbeddingParams())
125+
126+
assertNull(mockClient.lastParams?.dimensions)
127+
}
128+
78129
class MockEmbedderClient : LLMEmbeddingProvider {
79130
private val embeddings = mutableMapOf<String, Vector>()
131+
private val batchEmbeddings = mutableMapOf<List<String>, List<Vector>>()
132+
133+
/** Track the last params received for verification in tests */
134+
var lastParams: EmbeddingParams? = null
135+
private set
80136

81137
fun mockEmbedding(text: String, vector: Vector) {
82138
embeddings[text] = vector
83139
}
140+
141+
fun mockBatchEmbedding(texts: List<String>, vectors: List<Vector>) {
142+
batchEmbeddings[texts] = vectors
143+
}
84144

85-
override suspend fun embed(text: String, model: LLModel): List<Double> {
86-
return embeddings[text]?.values ?: throw IllegalArgumentException("No mock embedding for text: $text")
145+
override suspend fun embed(
146+
text: String,
147+
model: LLModel,
148+
params: EmbeddingParams
149+
): List<Double> {
150+
lastParams = params
151+
return embeddings[text]?.values
152+
?: throw IllegalArgumentException("No mock embedding for text: $text")
153+
}
154+
155+
override suspend fun embedBatch(
156+
texts: List<String>,
157+
model: LLModel,
158+
params: EmbeddingParams
159+
): List<List<Double>> {
160+
lastParams = params
161+
// Return mocked batch if available, otherwise fall back to individual embeds
162+
return batchEmbeddings[texts]?.map { it.values }
163+
?: texts.map { embed(it, model, params) }
87164
}
88165
}
89166
}
167+

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ExecutorIntegrationTestBase.kt

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,8 @@ abstract class ExecutorIntegrationTestBase {
904904
}
905905

906906
open fun integration_testEmbed(model: LLModel) = runTest {
907+
Models.assumeAvailable(model.provider)
908+
907909
val client = getLLMClient(model)
908910
if (client !is LLMEmbeddingProvider) {
909911
return@runTest
@@ -918,6 +920,63 @@ abstract class ExecutorIntegrationTestBase {
918920
}
919921
}
920922

923+
/**
924+
* Tests embedding with custom output dimensions.
925+
* Only runs for models that support [LLMCapability.Embedding.Dimensions].
926+
*/
927+
open fun integration_testEmbedWithDimensions(model: LLModel) = runTest {
928+
Models.assumeAvailable(model.provider)
929+
930+
val client = getLLMClient(model)
931+
if (client !is LLMEmbeddingProvider) {
932+
return@runTest
933+
}
934+
935+
// Only test if model supports dimensions
936+
assumeTrue(
937+
model.capabilities.contains(LLMCapability.Embedding.Dimensions),
938+
"Model ${model.id} does not support custom embedding dimensions"
939+
)
940+
941+
val testDimensions = 256
942+
val params = ai.koog.prompt.params.EmbeddingParams(dimensions = testDimensions)
943+
val result = client.embed("test embedding with dimensions", model, params)
944+
945+
result shouldNotBeNull {
946+
shouldNotBeEmpty()
947+
size shouldBe testDimensions
948+
shouldForAll { it.isFinite() }
949+
}
950+
}
951+
952+
/**
953+
* Tests batch embedding of multiple texts.
954+
*/
955+
open fun integration_testEmbedBatch(model: LLModel) = runTest {
956+
Models.assumeAvailable(model.provider)
957+
958+
val client = getLLMClient(model)
959+
if (client !is LLMEmbeddingProvider) {
960+
return@runTest
961+
}
962+
963+
val texts = listOf(
964+
"first text for batch embedding",
965+
"second text for batch embedding",
966+
"third text for batch embedding"
967+
)
968+
val results = client.embedBatch(texts, model)
969+
970+
results shouldNotBeNull {
971+
size shouldBe 3
972+
shouldForAll { embedding ->
973+
embedding.shouldNotBeEmpty()
974+
embedding.size shouldBeGreaterThan 100
975+
embedding.shouldForAll { it.isFinite() }
976+
}
977+
}
978+
}
979+
921980
open fun integration_testMultipleSystemMessages(model: LLModel) = runTest(timeout = 300.seconds) {
922981
Models.assumeAvailable(model.provider)
923982

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/MultipleLLMPromptExecutorIntegrationTest.kt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ class MultipleLLMPromptExecutorIntegrationTest : ExecutorIntegrationTestBase() {
4848
return Models.embeddingModels().map { model -> Arguments.of(model) }
4949
}
5050

51+
@JvmStatic
52+
fun dimensionCapableEmbeddingModels(): Stream<Arguments> {
53+
return Models.dimensionCapableEmbeddingModels().map { model -> Arguments.of(model) }
54+
}
55+
5156
@JvmStatic
5257
fun reasoningCapableModels(): Stream<Arguments> {
5358
return Models.reasoningCapableModels().map { model -> Arguments.of(model) }
@@ -239,6 +244,18 @@ class MultipleLLMPromptExecutorIntegrationTest : ExecutorIntegrationTestBase() {
239244
super.integration_testEmbed(model)
240245
}
241246

247+
@ParameterizedTest
248+
@MethodSource("dimensionCapableEmbeddingModels")
249+
override fun integration_testEmbedWithDimensions(model: LLModel) {
250+
super.integration_testEmbedWithDimensions(model)
251+
}
252+
253+
@ParameterizedTest
254+
@MethodSource("embeddingModels")
255+
override fun integration_testEmbedBatch(model: LLModel) {
256+
super.integration_testEmbedBatch(model)
257+
}
258+
242259
@ParameterizedTest
243260
@MethodSource("moderationModels")
244261
override fun integration_testSingleMessageModeration(model: LLModel) {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ class SingleLLMPromptExecutorIntegrationTest : ExecutorIntegrationTestBase() {
3333
return Models.embeddingModels().map { model -> Arguments.of(model) }
3434
}
3535

36+
@JvmStatic
37+
fun dimensionCapableEmbeddingModels(): Stream<Arguments> {
38+
return Models.dimensionCapableEmbeddingModels().map { model -> Arguments.of(model) }
39+
}
40+
3641
@JvmStatic
3742
fun bedrockMarkdownScenarioModelCombinations(): Stream<Arguments> {
3843
return Models.bedrockModels().flatMap { model ->
@@ -277,6 +282,18 @@ class SingleLLMPromptExecutorIntegrationTest : ExecutorIntegrationTestBase() {
277282
super.integration_testEmbed(model)
278283
}
279284

285+
@ParameterizedTest
286+
@MethodSource("dimensionCapableEmbeddingModels")
287+
override fun integration_testEmbedWithDimensions(model: LLModel) {
288+
super.integration_testEmbedWithDimensions(model)
289+
}
290+
291+
@ParameterizedTest
292+
@MethodSource("embeddingModels")
293+
override fun integration_testEmbedBatch(model: LLModel) {
294+
super.integration_testEmbedBatch(model)
295+
}
296+
280297
@ParameterizedTest
281298
@MethodSource("moderationModels")
282299
override fun integration_testSingleMessageModeration(model: LLModel) {

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,24 @@ object Models {
6161

6262
@JvmStatic
6363
fun embeddingModels(): Stream<LLModel> {
64+
// TODO: Re-enable after provider migration complete
6465
return Stream.of(
65-
BedrockModels.Embeddings.AmazonTitanEmbedText,
66-
OpenAIModels.Embeddings.TextEmbedding3Large,
67-
MistralAIModels.Embeddings.MistralEmbed,
66+
// BedrockModels.Embeddings.AmazonTitanEmbedText,
67+
// OpenAIModels.Embeddings.TextEmbedding3Large,
68+
// MistralAIModels.Embeddings.MistralEmbed,
6869
GoogleModels.Embeddings.GeminiEmbedding001,
6970
)
7071
}
7172

73+
/**
74+
* Returns embedding models that support variable output dimensions via the dimensions parameter.
75+
* Only includes models with [LLMCapability.Embedding.Dimensions].
76+
*/
77+
@JvmStatic
78+
fun dimensionCapableEmbeddingModels(): Stream<LLModel> {
79+
return embeddingModels().filter { it.capabilities.contains(LLMCapability.Embedding.Dimensions) }
80+
}
81+
7282
/**
7383
* Returns models that support content moderation capabilities.
7484
*
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package ai.koog.prompt.executor.clients.google
2+
3+
import ai.koog.prompt.params.EmbeddingParams
4+
import kotlinx.serialization.Serializable
5+
6+
/**
7+
* Task type for Google embedding API.
8+
* Specifies the intended use case to help the model produce better embeddings.
9+
*
10+
* **Polymorphic Usage**: Users can call `embed()` with either:
11+
* - Generic `EmbeddingParams(dimensions = 256)` - works with any provider
12+
* - Specific `GoogleEmbeddingParams(dimensions = 256, taskType = RETRIEVAL_QUERY)` - Google-specific features
13+
*
14+
* The conversion function [toGoogleEmbeddingParams] handles both cases transparently.
15+
*/
16+
@Serializable
17+
public enum class GoogleEmbeddingTaskType(public val apiValue: String) {
18+
/** Query for search/retrieval. Use RETRIEVAL_DOCUMENT for the document side. */
19+
RETRIEVAL_QUERY("RETRIEVAL_QUERY"),
20+
21+
/** Document for search/retrieval. */
22+
RETRIEVAL_DOCUMENT("RETRIEVAL_DOCUMENT"),
23+
24+
/** Semantic textual similarity comparison. */
25+
SEMANTIC_SIMILARITY("SEMANTIC_SIMILARITY"),
26+
27+
/** Embeddings for classification tasks. */
28+
CLASSIFICATION("CLASSIFICATION"),
29+
30+
/** Embeddings for clustering tasks. */
31+
CLUSTERING("CLUSTERING"),
32+
33+
/** Query for question answering. Use RETRIEVAL_DOCUMENT for the document side. */
34+
QUESTION_ANSWERING("QUESTION_ANSWERING"),
35+
36+
/** Query for fact verification. Use RETRIEVAL_DOCUMENT for the document side. */
37+
FACT_VERIFICATION("FACT_VERIFICATION"),
38+
39+
/** Query for code retrieval (Java/Python). Use RETRIEVAL_DOCUMENT for the document side. */
40+
CODE_RETRIEVAL_QUERY("CODE_RETRIEVAL_QUERY"),
41+
}
42+
43+
/**
44+
* Google-specific embedding parameters.
45+
*
46+
* @property dimensions Desired output embedding dimensions (mapped to `outputDimensionality`).
47+
* @property taskType Specifies the intended use case for the embeddings.
48+
* @property title Document title (only valid with taskType=RETRIEVAL_DOCUMENT).
49+
*/
50+
public class GoogleEmbeddingParams(
51+
dimensions: Int? = null,
52+
public val taskType: GoogleEmbeddingTaskType? = null,
53+
public val title: String? = null,
54+
) : EmbeddingParams(dimensions) {
55+
56+
init {
57+
// title is only valid with RETRIEVAL_DOCUMENT
58+
if (title != null) {
59+
require(taskType == GoogleEmbeddingTaskType.RETRIEVAL_DOCUMENT) {
60+
"title parameter is only valid when taskType is RETRIEVAL_DOCUMENT"
61+
}
62+
}
63+
}
64+
65+
override fun copy(dimensions: Int?): GoogleEmbeddingParams =
66+
GoogleEmbeddingParams(dimensions, taskType, title)
67+
68+
override fun equals(other: Any?): Boolean = when {
69+
this === other -> true
70+
other !is GoogleEmbeddingParams -> false
71+
else -> dimensions == other.dimensions &&
72+
taskType == other.taskType &&
73+
title == other.title
74+
}
75+
76+
override fun hashCode(): Int {
77+
var result = dimensions?.hashCode() ?: 0
78+
result = 31 * result + (taskType?.hashCode() ?: 0)
79+
result = 31 * result + (title?.hashCode() ?: 0)
80+
return result
81+
}
82+
83+
override fun toString(): String =
84+
"GoogleEmbeddingParams(dimensions=$dimensions, taskType=$taskType, title=$title)"
85+
}
86+
87+
/**
88+
* Converts generic [EmbeddingParams] to [GoogleEmbeddingParams].
89+
* Follows the same pattern as [LLMParams.toGoogleParams].
90+
*/
91+
internal fun EmbeddingParams.toGoogleEmbeddingParams(): GoogleEmbeddingParams = when (this) {
92+
is GoogleEmbeddingParams -> this
93+
else -> GoogleEmbeddingParams(dimensions = dimensions)
94+
}
95+

0 commit comments

Comments
 (0)