Skip to content

Commit bfd3a67

Browse files
committed
feat(embedding): Add EmbeddingParams and batch embedding support to LLMEmbeddingProvider
- Add LLMCapability.Embedding.Dimensions for models supporting variable output dimensions
- Add EmbeddingParams base class in prompt-model with dimensions parameter
- Update LLMEmbeddingProvider interface:
  - Add EmbeddingParams parameter to embed() method
  - Add embedBatch() with default parallel polyfill implementation
- Apply minimal signature updates to all provider clients (OpenAI, Mistral, Ollama, Bedrock) to enable compilation (full implementation pending)

Part of KG-104 (dimension control) and KG-538 (batch embedding)
1 parent 3c0cca2 commit bfd3a67

File tree

7 files changed

+125
-6
lines changed
  • prompt
    • prompt-executor/prompt-executor-clients
      • prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock
      • prompt-executor-mistralai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/mistralai
      • prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client
      • prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai
      • src/commonMain/kotlin/ai/koog/prompt/executor/clients
    • prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm
    • prompt-model/src/commonMain/kotlin/ai/koog/prompt/params

7 files changed

+125
-6
lines changed

prompt/prompt-executor/prompt-executor-clients/prompt-executor-bedrock-client/src/jvmMain/kotlin/ai/koog/prompt/executor/clients/bedrock/BedrockLLMClient.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,11 @@ public class BedrockLLMClient(
354354
}
355355
}
356356

357-
override suspend fun embed(text: String, model: LLModel): List<Double> {
357+
override suspend fun embed(
358+
text: String,
359+
model: LLModel,
360+
params: ai.koog.prompt.params.EmbeddingParams
361+
): List<Double> {
358362
model.requireCapability(LLMCapability.Embed)
359363
logger.debug { "Embedding text with model: ${model.id}" }
360364
val modelFamily = getBedrockModelFamily(model)

prompt/prompt-executor/prompt-executor-clients/prompt-executor-mistralai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/mistralai/MistralAILLMClient.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,15 @@ public open class MistralAILLMClient(
177177
*
178178
* @param text The text to embed.
179179
* @param model The model to use for embedding. Must have the Embed capability.
180+
* @param params Embedding parameters. Support for `dimensions` is not implemented yet (TODO in the MistralAI migration).
180181
* @return A list of floating-point values representing the embedding.
181182
* @throws IllegalArgumentException if the model does not have the Embed capability.
182183
*/
183-
override suspend fun embed(text: String, model: LLModel): List<Double> {
184+
override suspend fun embed(
185+
text: String,
186+
model: LLModel,
187+
params: ai.koog.prompt.params.EmbeddingParams
188+
): List<Double> {
184189
model.requireCapability(LLMCapability.Embed)
185190

186191
logger.debug { "Embedding text with model: ${model.id}" }

prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,15 @@ public class OllamaClient(
330330
*
331331
* @param text The text to embed.
332332
* @param model The model to use for embedding. Must have the Embed capability.
333+
* @param params Embedding parameters. Support for `dimensions` is not implemented yet (TODO in the Ollama migration).
333334
* @return A vector representation of the text.
334335
* @throws LLMClientException if the model does not have the Embed capability.
335336
*/
336-
override suspend fun embed(text: String, model: LLModel): List<Double> {
337+
override suspend fun embed(
338+
text: String,
339+
model: LLModel,
340+
params: ai.koog.prompt.params.EmbeddingParams
341+
): List<Double> {
337342
require(model.provider == LLMProvider.Ollama) { "Model not supported by Ollama" }
338343

339344
if (!model.capabilities.contains(LLMCapability.Embed)) {

prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,15 @@ public open class OpenAILLMClient(
398398
*
399399
* @param text The text to embed.
400400
* @param model The model to use for embedding. Must have the Embed capability.
401+
* @param params Embedding parameters. Support for `dimensions` is not implemented yet (TODO in the OpenAI migration).
401402
* @return A list of floating-point values representing the embedding.
402403
* @throws IllegalArgumentException if the model does not have the Embed capability.
403404
*/
404-
override suspend fun embed(text: String, model: LLModel): List<Double> {
405+
override suspend fun embed(
406+
text: String,
407+
model: LLModel,
408+
params: ai.koog.prompt.params.EmbeddingParams
409+
): List<Double> {
405410
model.requireCapability(LLMCapability.Embed)
406411

407412
logger.debug { "Embedding text with model: ${model.id}" }
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,52 @@
11
package ai.koog.prompt.executor.clients
22

33
import ai.koog.prompt.llm.LLModel
4+
import ai.koog.prompt.params.EmbeddingParams
5+
import kotlinx.coroutines.async
6+
import kotlinx.coroutines.awaitAll
7+
import kotlinx.coroutines.coroutineScope
48

59
/**
610
* Contract for clients that can generate text embeddings, either for a single text ([embed])
711
* or for a list of texts ([embedBatch]); the interface declares no supertype of its own.
812
*/
913
public interface LLMEmbeddingProvider {
1014
/**
11-
* Embeds the given text using into a vector of double-precision numbers.
15+
* Embeds the given text into a vector of double-precision numbers.
1216
*
1317
* @param text The text to embed.
1418
* @param model The model to use for embedding. Must have the Embed capability.
19+
* @param params Embedding parameters such as output dimensions; defaults to `EmbeddingParams()`.
1520
* @return A list of floating-point values representing the embedding.
21+
* @throws IllegalArgumentException if the model does not have the Embed capability,
22+
* or if dimensions are specified but the model lacks Embedding.Dimensions capability.
23+
*/
24+
public suspend fun embed(
25+
text: String,
26+
model: LLModel,
27+
params: EmbeddingParams = EmbeddingParams()
28+
): List<Double>
29+
30+
/**
31+
* Embeds multiple texts in a batch.
32+
*
33+
* Default implementation starts one concurrent [embed] call per text (no limit on concurrency)
34+
* and awaits them all. Providers with native batch APIs should override this for better performance.
35+
*
36+
* @param texts The list of texts to embed.
37+
* @param model The model to use for embedding. Must have the Embed capability.
38+
* @param params Embedding parameters such as output dimensions; defaults to `EmbeddingParams()`.
39+
* @return A list of embeddings, one for each input text, in the same order as [texts].
1640
* @throws IllegalArgumentException if the model does not have the Embed capability.
1741
*/
18-
public suspend fun embed(text: String, model: LLModel): List<Double>
42+
public suspend fun embedBatch(
43+
texts: List<String>,
44+
model: LLModel,
45+
params: EmbeddingParams = EmbeddingParams()
46+
): List<List<Double>> = coroutineScope {
47+
texts.map { text ->
48+
async { embed(text, model, params) }
49+
}.awaitAll() // results are returned in the order of the deferreds, i.e. the order of texts
50+
}
1951
}
52+

prompt/prompt-llm/src/commonMain/kotlin/ai/koog/prompt/llm/LLMCapability.kt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,34 @@ public sealed class LLMCapability(public val id: String) {
119119
*
120120
* This capability can be utilized in tasks like semantic search, document clustering,
121121
* or other operations requiring an understanding of textual similarity.
122+
*
123+
* Note: For models that support additional embedding features (like variable dimensions),
124+
* see [Embedding] sealed class for more granular capabilities.
122125
*/
123126
@Serializable
124127
public data object Embed : LLMCapability("embed")
125128

129+
/**
130+
* Represents embedding-specific capabilities beyond basic embedding generation.
131+
*
132+
* These capabilities are **additive** to the base [Embed] capability. A model should have
133+
* both [Embed] and the relevant [Embedding] sub-capability. For example:
134+
* - OpenAI text-embedding-ada-002: has [Embed] only
135+
* - OpenAI text-embedding-3-small: has [Embed] AND [Embedding.Dimensions]
136+
*
137+
* @property embeddingFeature The specific embedding feature identifier; also passed to [LLMCapability] as its `id`.
138+
*/
139+
@Serializable
140+
public sealed class Embedding(public val embeddingFeature: String) : LLMCapability(embeddingFeature) {
141+
/**
142+
* Indicates that the model supports variable output dimensions for embeddings.
143+
* Models with this capability can accept a `dimensions` parameter to control
144+
* the size of the output embedding vector (see `EmbeddingParams.dimensions`).
145+
*/
146+
@Serializable
147+
public data object Dimensions : Embedding("embedding-dimensions")
148+
}
149+
126150
/**
127151
* Represents the "completion" capability for Language Learning Models (LLMs). This capability
128152
* typically encompasses the generation of text or content based on the given input context.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package ai.koog.prompt.params
2+
3+
import kotlinx.serialization.Serializable
4+
5+
/**
6+
* Parameters for embedding generation.
7+
*
8+
* This is an `open class` (not a `data class`) to allow provider-specific subclasses
9+
* (e.g., [ai.koog.prompt.executor.clients.google.GoogleEmbeddingParams]) to add additional
10+
* parameters while maintaining polymorphism. This mirrors the [LLMParams] pattern.
11+
*
12+
* @property dimensions Desired output embedding dimensions; must be > 0 when non-null.
13+
* Only applicable to models that support variable dimensions
14+
* (models with [ai.koog.prompt.llm.LLMCapability.Embedding.Dimensions] capability).
15+
* If null, uses model's default dimension.
16+
*/
17+
@Serializable
18+
public open class EmbeddingParams(
19+
public val dimensions: Int? = null,
20+
) {
21+
init {
22+
dimensions?.let {
23+
require(it > 0) { "dimensions must be > 0, but was $it" } // throws IllegalArgumentException on zero or negative values
24+
}
25+
}
26+
27+
/**
28+
* Returns a new base [EmbeddingParams] with [dimensions] replaced; subclasses should override this to keep their own fields.
29+
*/
30+
public open fun copy(dimensions: Int? = this.dimensions): EmbeddingParams =
31+
EmbeddingParams(dimensions)
32+
33+
override fun equals(other: Any?): Boolean = when {
34+
this === other -> true
35+
other !is EmbeddingParams -> false
36+
else -> dimensions == other.dimensions // NOTE(review): only dimensions is compared; fields added by subclasses are not considered here
37+
}
38+
39+
override fun hashCode(): Int = dimensions?.hashCode() ?: 0
40+
41+
override fun toString(): String = "EmbeddingParams(dimensions=$dimensions)"
42+
}
43+

0 commit comments

Comments
 (0)