
Commit 48a66d9

Authored by Malte Heuser (mltheuser)
KG-596 Fix thought signature on tool call with gemini 3.0 (#1317)
## Motivation and Context

Related to: [KG-596](https://youtrack.jetbrains.com/projects/KG/issues/KG-596)

Gemini 3.0 models enforce stricter validation of `thoughtSignature` for function calls. When the model returns parallel tool calls, only the *first* call in a turn receives a `thoughtSignature`. On subsequent turns, the API expects this signature to be echoed back exactly. Without proper handling, multi-turn agentic conversations with parallel tools fail with cryptic API errors.

**The Problem:** The `GoogleLLMClient` wasn't preserving `thoughtSignature` across turns, and wasn't correctly re-grouping parallel tool calls/results when constructing requests, leading to malformed conversation structures that newer Gemini models reject.

## How It's Solved

1. **Preserve `thoughtSignature`:** When processing model responses, we now extract the `thoughtSignature` from `GooglePart.FunctionCall` and store it in `Message.Tool.Call.metaInfo.metadata`. When building subsequent requests, we restore it.
2. **Correctly batch parallel tool calls/results:** The `createGoogleRequest` function now uses a buffering strategy to re-group interleaved messages. The key insight: if a tool call has a `thoughtSignature`, it starts a new turn; if it doesn't, it's a parallel call in the same turn. This lets us batch calls into a single `model`-role `GoogleContent` and results into a single `user`-role `GoogleContent`, as the API requires.
3. **Clean, idiomatic implementation:** The buffering logic uses a `when` expression to make the three states explicit (new turn, starting fresh, parallel call), keeping the code readable and maintainable.

## Breaking Changes

None. This is a backward-compatible fix that enables correct behavior with Gemini 3.0+ while remaining compatible with earlier models.

---

#### Type of the changes

- [ ] New feature (non-breaking change which adds functionality)
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Tests improvement
- [ ] Refactoring

#### Checklist

- [x] The pull request has a description of the proposed change
- [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request
- [x] The pull request uses **`develop`** as the base branch
- [x] Tests for the changes have been added
- [x] All new and existing tests passed

##### Additional steps for pull requests adding a new feature

- [x] An issue describing the proposed change exists
- [x] The pull request includes a link to the issue
- [ ] The change was discussed and approved in the issue
- [ ] Docs have been added / updated

---------

Co-authored-by: Malte Heuser <[email protected]>
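The turn-grouping rule from point 2 above ("a signature starts a new turn") can be sketched as a small standalone snippet. `ToolCall` and `groupIntoTurns` here are hypothetical simplified stand-ins for illustration, not the actual Koog types:

```kotlin
// Hypothetical simplified sketch of the turn-grouping rule: a tool call that
// carries a thoughtSignature opens a new turn; a signature-less call is a
// parallel call belonging to the current turn.
data class ToolCall(val name: String, val thoughtSignature: String? = null)

fun groupIntoTurns(calls: List<ToolCall>): List<List<ToolCall>> {
    val turns = mutableListOf<MutableList<ToolCall>>()
    for (call in calls) {
        if (call.thoughtSignature != null || turns.isEmpty()) {
            // Signature present (or nothing buffered yet): start a new turn
            turns += mutableListOf(call)
        } else {
            // No signature: parallel call in the same turn
            turns.last() += call
        }
    }
    return turns
}

fun main() {
    val calls = listOf(
        ToolCall("search", thoughtSignature = "sig-A"), // turn 1 starts
        ToolCall("fetch"),                              // parallel in turn 1
        ToolCall("summarize", thoughtSignature = "sig-B") // turn 2 starts
    )
    println(groupIntoTurns(calls).map { turn -> turn.map { it.name } })
    // [[search, fetch], [summarize]]
}
```

Each inner list then maps to one `model`-role content in the request, with the signature attached only to the first call of its turn.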
1 parent 6a25db1 commit 48a66d9

File tree (3 files changed: +237 −45)

  • integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils
  • prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src
integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt

Lines changed: 1 addition & 0 deletions

```diff
@@ -35,6 +35,7 @@ object Models {
     @JvmStatic
     fun googleModels(): Stream<LLModel> {
         return Stream.of(
+            GoogleModels.Gemini3_Pro_Preview,
             GoogleModels.Gemini2_5Pro,
             GoogleModels.Gemini2_5Flash,
         )
```

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt

Lines changed: 72 additions & 45 deletions

```diff
@@ -287,6 +287,8 @@ public open class GoogleLLMClient(
         val systemMessageParts = mutableListOf<GooglePart.Text>()
         val contents = mutableListOf<GoogleContent>()
         val pendingCalls = mutableListOf<GooglePart.FunctionCall>()
+        val pendingResults = mutableListOf<GooglePart.FunctionResponse>()
+        var lastSignature: String? = null

         fun flushCalls() {
             if (pendingCalls.isNotEmpty()) {
@@ -295,20 +297,32 @@
             }
         }

+        fun flushResults() {
+            if (pendingResults.isNotEmpty()) {
+                contents += GoogleContent(role = "user", parts = pendingResults.toList())
+                pendingResults.clear()
+            }
+        }
+
+        fun flushAll() {
+            flushCalls()
+            flushResults()
+        }
+
         for (message in prompt.messages) {
             when (message) {
                 is Message.System -> {
                     systemMessageParts.add(GooglePart.Text(message.content))
                 }

                 is Message.User -> {
-                    flushCalls()
+                    flushAll()
                     // User messages become 'user' role content
                     contents.add(message.toGoogleContent(model))
                 }

                 is Message.Assistant -> {
-                    flushCalls()
+                    flushAll()
                     contents.add(
                         GoogleContent(
                             role = "model",
@@ -318,51 +332,64 @@
                 }

                 is Message.Reasoning -> {
-                    flushCalls()
-                    contents.add(
-                        GoogleContent(
-                            role = "model",
-                            parts = listOf(
-                                GooglePart.Text(
-                                    text = message.content,
-                                    thoughtSignature = message.encrypted,
-                                    thought = true,
+                    // Reasoning indicates a new step - flush previous step
+                    flushAll()
+
+                    if (message.content.isNotBlank()) {
+                        // If content is present, it's a "Thought Summary" -> Convert to Text part with thought=true
+                        contents.add(
+                            GoogleContent(
+                                role = "model",
+                                parts = listOf(
+                                    GooglePart.Text(
+                                        text = message.content,
+                                        thought = true,
+                                        thoughtSignature = message.encrypted
+                                    )
                                 )
                             )
                         )
-                    )
+                    } else {
+                        // If content is empty/blank, it's strictly a signature carrier for the next Tool.Call
+                        lastSignature = message.encrypted
+                    }
                 }

                 is Message.Tool.Result -> {
-                    flushCalls()
-                    contents.add(
-                        GoogleContent(
-                            role = "user",
-                            parts = listOf(
-                                GooglePart.FunctionResponse(
-                                    functionResponse = GoogleData.FunctionResponse(
-                                        id = message.id,
-                                        name = message.tool,
-                                        response = buildJsonObject { put("result", message.content) }
-                                    )
-                                )
+                    // Just buffer results. We only flush when we know the current tool turn is complete.
+                    pendingResults.add(
+                        GooglePart.FunctionResponse(
+                            functionResponse = GoogleData.FunctionResponse(
+                                id = message.id,
+                                name = message.tool,
+                                response = buildJsonObject { put("result", message.content) }
                             )
                         )
                     )
                 }

                 is Message.Tool.Call -> {
+                    // First call in step needs to flush stale results
+                    if (pendingCalls.isEmpty()) {
+                        flushResults()
+                    }
+
+                    // Use signature from preceding Reasoning message
+                    val signature = lastSignature
+                    lastSignature = null // Consume: only first call gets the signature
+
                     pendingCalls += GooglePart.FunctionCall(
                         functionCall = GoogleData.FunctionCall(
                             id = message.id,
                             name = message.tool,
                             args = json.decodeFromString(message.content)
-                        )
+                        ),
+                        thoughtSignature = signature
                     )
                 }
             }
         }
-        flushCalls()
+        flushAll()

         val googleTools = tools
             .map { tool ->
@@ -599,23 +626,21 @@
         val responses = mutableListOf<Message.Response>()
         with(responses) {
             parts.forEach { part ->
-                if (part.thoughtSignature != null && part.thought == false) {
-                    add(
-                        Message.Reasoning(
-                            encrypted = part.thoughtSignature,
-                            content = "",
-                            metaInfo = metaInfo
-                        )
-                    )
+                // Create Reasoning for any part with signature (signature carrier),
+                // unless the part itself is a thought (in which case it carries the signature)
+                val signature = part.thoughtSignature
+                val isThought = part.thought == true
+                if (signature != null && !isThought) {
+                    add(Message.Reasoning(encrypted = signature, content = "", metaInfo = metaInfo))
                 }

                 when (part) {
                     is GooglePart.Text -> {
-                        if (part.thought ?: false) {
+                        if (isThought) {
                             add(
                                 Message.Reasoning(
-                                    encrypted = part.thoughtSignature,
                                     content = part.text,
+                                    encrypted = signature,
                                     metaInfo = metaInfo
                                 )
                             )
@@ -630,14 +655,16 @@
                         }
                     }

-                    is GooglePart.FunctionCall -> add(
-                        Message.Tool.Call(
-                            id = Uuid.random().toString(),
-                            tool = part.functionCall.name,
-                            content = part.functionCall.args.toString(),
-                            metaInfo = metaInfo
+                    is GooglePart.FunctionCall -> {
+                        add(
+                            Message.Tool.Call(
+                                id = Uuid.random().toString(),
+                                tool = part.functionCall.name,
+                                content = part.functionCall.args.toString(),
+                                metaInfo = metaInfo
+                            )
                         )
-                    )
+                    }

                     is GooglePart.InlineData -> {
                         val inlineData = part.inlineData
@@ -669,8 +696,8 @@
         }

         return when {
-            // Fix the situation when the model decides to both call tools and talk
-            responses.any { it is Message.Tool.Call } -> responses.filterIsInstance<Message.Tool.Call>()
+            // When the model calls tools, keep Reasoning (for signature) and Tool.Call, filter out Assistant text
+            responses.any { it is Message.Tool.Call } -> responses.filter { it is Message.Reasoning || it is Message.Tool.Call }
             // If no messages where returned, return an empty message and check finishReason
             responses.isEmpty() -> listOf(
                 Message.Assistant(
```
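As a rough illustration of the buffering strategy in the hunks above, here is a self-contained sketch with hypothetical simplified types (`Msg`, `Content`, and friends stand in for the real `Message` and `GoogleContent` classes): tool calls accumulate until a non-call message arrives and then flush as one `model` content, while their results flush together as one `user` content.

```kotlin
// Hypothetical simplified sketch of the buffering strategy, not the real
// Koog/Google client types.
sealed interface Msg
data class UserMsg(val text: String) : Msg
data class Call(val tool: String) : Msg
data class Result(val tool: String) : Msg

data class Content(val role: String, val parts: List<String>)

fun toContents(messages: List<Msg>): List<Content> {
    val contents = mutableListOf<Content>()
    val pendingCalls = mutableListOf<String>()
    val pendingResults = mutableListOf<String>()

    // Flush buffered parallel calls as a single "model" content
    fun flushCalls() {
        if (pendingCalls.isNotEmpty()) {
            contents += Content("model", pendingCalls.toList())
            pendingCalls.clear()
        }
    }

    // Flush buffered results as a single "user" content
    fun flushResults() {
        if (pendingResults.isNotEmpty()) {
            contents += Content("user", pendingResults.toList())
            pendingResults.clear()
        }
    }

    for (m in messages) when (m) {
        is UserMsg -> {
            flushCalls(); flushResults()
            contents += Content("user", listOf(m.text))
        }
        is Call -> {
            // First call of a new turn flushes stale results from the previous turn
            if (pendingCalls.isEmpty()) flushResults()
            pendingCalls += "call:${m.tool}"
        }
        is Result -> pendingResults += "result:${m.tool}"
    }
    flushCalls(); flushResults()
    return contents
}

fun main() {
    val contents = toContents(
        listOf(UserMsg("query"), Call("t1"), Call("t2"), Result("t1"), Result("t2"))
    )
    // Parallel calls share one "model" content; their results share one "user" content
    println(contents.map { it.role })
    // [user, model, user]
}
```

This mirrors the structure asserted in the `createGoogleRequest groups parallel Tool Results into single content` test below.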

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt

Lines changed: 164 additions & 0 deletions

```diff
@@ -13,14 +13,17 @@ import ai.koog.prompt.executor.clients.google.models.GoogleThinkingConfig
 import ai.koog.prompt.message.AttachmentContent
 import ai.koog.prompt.message.ContentPart
 import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.RequestMetaInfo
 import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.params.LLMParams
 import io.kotest.matchers.collections.shouldContain
 import io.kotest.matchers.collections.shouldHaveSize
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.shouldNotBe
+import io.kotest.matchers.types.shouldBeInstanceOf
 import kotlinx.serialization.json.JsonObject
 import kotlinx.serialization.json.JsonPrimitive
+import kotlinx.serialization.json.buildJsonObject
 import kotlinx.serialization.json.jsonArray
 import kotlinx.serialization.json.jsonObject
 import kotlinx.serialization.json.jsonPrimitive
@@ -384,4 +387,165 @@ class GoogleLLMClientTest {
         filePart.mimeType shouldBe "application/pdf"
         (filePart.content as AttachmentContent.Binary.Bytes).asBytes() shouldBe fileData
     }
+
+    @Test
+    fun `createGoogleRequest groups parallel Tool Results into single content`() {
+        val client = GoogleLLMClient(apiKey = "test")
+        val request = client.createGoogleRequest(
+            Prompt(
+                messages = listOf(
+                    Message.User("query", RequestMetaInfo.Empty),
+                    Message.Reasoning(encrypted = "sig", content = "", metaInfo = ResponseMetaInfo.Empty),
+                    Message.Tool.Call(id = "1", tool = "t1", content = "{}", metaInfo = ResponseMetaInfo.Empty),
+                    Message.Tool.Call(id = "2", tool = "t2", content = "{}", metaInfo = ResponseMetaInfo.Empty),
+                    Message.Tool.Result(id = "1", tool = "t1", content = "r1", metaInfo = RequestMetaInfo.Empty),
+                    Message.Tool.Result(id = "2", tool = "t2", content = "r2", metaInfo = RequestMetaInfo.Empty),
+                ),
+                id = "id"
+            ),
+            GoogleModels.Gemini3_Pro_Preview,
+            emptyList()
+        )
+
+        // Structure: User, FunctionCalls(grouped), FunctionResponses(grouped)
+        request.contents shouldHaveSize 3
+        request.contents[0].role shouldBe "user"
+        request.contents[1].role shouldBe "model"
+        request.contents[2].role shouldBe "user"
+
+        // FunctionResponses are grouped
+        val responsesParts = request.contents[2].parts!!
+        responsesParts shouldHaveSize 2
+        responsesParts.forEach { it.shouldBeInstanceOf<GooglePart.FunctionResponse>() }
+    }
+
+    @Test
+    fun `createGoogleRequest attaches signature from Reasoning to first call only`() {
+        val client = GoogleLLMClient(apiKey = "test")
+        val request = client.createGoogleRequest(
+            Prompt(
+                messages = listOf(
+                    Message.User("query", RequestMetaInfo.Empty),
+                    Message.Reasoning(encrypted = "my-sig", content = "", metaInfo = ResponseMetaInfo.Empty),
+                    Message.Tool.Call(id = "1", tool = "t1", content = "{}", metaInfo = ResponseMetaInfo.Empty),
+                    Message.Tool.Call(id = "2", tool = "t2", content = "{}", metaInfo = ResponseMetaInfo.Empty),
+                ),
+                id = "id"
+            ),
+            GoogleModels.Gemini3_Pro_Preview,
+            emptyList()
+        )
+
+        val callsParts = request.contents[1].parts!!
+        callsParts shouldHaveSize 2
+
+        val fc1 = callsParts[0] as GooglePart.FunctionCall
+        val fc2 = callsParts[1] as GooglePart.FunctionCall
+
+        fc1.thoughtSignature shouldBe "my-sig" // First gets signature
+        fc2.thoughtSignature shouldBe null // Second doesn't
+    }
+
+    @Test
+    fun `processGoogleCandidate creates Reasoning before FunctionCall with signature`() {
+        val client = GoogleLLMClient(apiKey = "test")
+        val candidate = GoogleCandidate(
+            content = GoogleContent(
+                role = "model",
+                parts = listOf(
+                    GooglePart.FunctionCall(
+                        functionCall = GoogleData.FunctionCall(name = "tool", args = buildJsonObject {}),
+                        thoughtSignature = "sig-123"
+                    )
+                )
+            ),
+            finishReason = "STOP"
+        )
+
+        val responses = client.processGoogleCandidate(candidate, ResponseMetaInfo.Empty)
+
+        responses shouldHaveSize 2
+        responses[0].shouldBeInstanceOf<Message.Reasoning>()
+        responses[1].shouldBeInstanceOf<Message.Tool.Call>()
+        (responses[0] as Message.Reasoning).encrypted shouldBe "sig-123"
+        (responses[0] as Message.Reasoning).content shouldBe ""
+    }
+
+    @Test
+    fun `processGoogleCandidate creates Reasoning from Text with thought=true`() {
+        val client = GoogleLLMClient(apiKey = "test")
+        val candidate = GoogleCandidate(
+            content = GoogleContent(
+                role = "model",
+                parts = listOf(
+                    GooglePart.Text(
+                        text = "I am thinking...",
+                        thought = true,
+                        thoughtSignature = "thought-sig"
+                    )
+                )
+            ),
+            finishReason = "STOP"
+        )
+
+        val responses = client.processGoogleCandidate(candidate, ResponseMetaInfo.Empty)
+
+        responses shouldHaveSize 1
+        responses[0].shouldBeInstanceOf<Message.Reasoning>()
+        val reasoning = responses[0] as Message.Reasoning
+        reasoning.content shouldBe "I am thinking..."
+        reasoning.encrypted shouldBe "thought-sig"
+    }
+
+    @Test
+    fun `createGoogleRequest includes Reasoning as Text part with thought=true`() {
+        val client = GoogleLLMClient(apiKey = "test")
+        val request = client.createGoogleRequest(
+            Prompt(
+                messages = listOf(
+                    Message.User("query", RequestMetaInfo.Empty),
+                    Message.Reasoning(content = "Previous thought", encrypted = "prev-sig", metaInfo = ResponseMetaInfo.Empty)
+                ),
+                id = "id"
+            ),
+            GoogleModels.Gemini3_Pro_Preview,
+            emptyList()
+        )
+
+        request.contents shouldHaveSize 2
+        val thoughtContent = request.contents[1]
+        thoughtContent.role shouldBe "model"
+        thoughtContent.parts!!.single().shouldBeInstanceOf<GooglePart.Text>()
+        val textPart = thoughtContent.parts!!.single() as GooglePart.Text
+        textPart.text shouldBe "Previous thought"
+        textPart.thought shouldBe true
+        textPart.thoughtSignature shouldBe "prev-sig"
+    }
+
+    @Test
+    fun `processGoogleCandidate creates Reasoning for InlineData with signature`() {
+        val client = GoogleLLMClient(apiKey = "test")
+        val candidate = GoogleCandidate(
+            content = GoogleContent(
+                role = "model",
+                parts = listOf(
+                    GooglePart.InlineData(
+                        inlineData = GoogleData.Blob("image/png", "png-bytes".encodeToByteArray()),
+                        thoughtSignature = "image-sig"
+                    )
+                )
+            ),
+            finishReason = "STOP"
+        )
+
+        val responses = client.processGoogleCandidate(candidate, ResponseMetaInfo.Empty)
+
+        responses shouldHaveSize 2
+        responses[0].shouldBeInstanceOf<Message.Reasoning>()
+        (responses[0] as Message.Reasoning).encrypted shouldBe "image-sig"
+
+        responses[1].shouldBeInstanceOf<Message.Assistant>()
+        val filePart = (responses[1] as Message.Assistant).parts.single() as ContentPart.Image
+        filePart.format shouldBe "png"
+    }
 }
```
