Skip to content

Commit 0054d4a

Browse files
mltheuser and Malte Heuser authored
KG-549 Add handler for JsonElement in JsonSchemaGenerator (#1140)
Related to: [KG-549](https://youtrack.jetbrains.com/issue/KG-549)

## Motivation and Context

The `StandardJsonSchemaGenerator` fails with an `IllegalArgumentException` when attempting to generate a schema for any serializable class containing a `JsonElement` property.

This occurs because the `processPolymorphic` method relies on an internal structural assumption about `kotlinx.serialization`. For a typical user-defined `sealed class`, the library generates a wrapper object in JSON containing a discriminator (e.g., `"type": "SubclassName"`) and a content field (e.g., `"value": {...}`). The `SerialDescriptor` for such a class mirrors this structure, presenting its elements as a `(type, value)` pair. The generator's current implementation hardcodes a check for this second `"value"` element to find the subtype definitions.

However, `JsonElement` is a special-cased, primitive polymorphic type. It is not serialized within a wrapper; it serializes directly to its raw JSON representation. Consequently, its `SerialDescriptor` does not have the `(type, value)` structure. Instead, its elements are a direct list of its possible subtypes (`JsonObject`, `JsonArray`, etc.). This structural mismatch causes the generator's hardcoded check to fail, resulting in the `IllegalArgumentException`.

This change introduces a special case to detect `JsonElement` and generates a permissive, empty schema (`{}`), which is the correct JSON Schema representation for "any valid JSON value". The fix is implemented in a type-safe manner and correctly handles both nullable (`JsonElement?`) and non-nullable properties.

## Breaking Changes

None. This is a non-breaking bug fix that enables previously unsupported functionality.
---

#### Type of the changes

- [ ] New feature (non-breaking change which adds functionality)
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Tests improvement
- [ ] Refactoring

#### Checklist

- [x] The pull request has a description of the proposed change
- [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request
- [x] The pull request uses **`develop`** as the base branch
- [x] Tests for the changes have been added
- [x] All new and existing tests passed

##### Additional steps for pull requests adding a new feature

- [x] An issue describing the proposed change exists
- [x] The pull request includes a link to the issue
- [x] The change was discussed and approved in the issue
- [ ] Docs have been added / updated

---------

Co-authored-by: Malte Heuser <[email protected]>
1 parent 0d7179c commit 0054d4a

File tree

5 files changed

+172
-2
lines changed

5 files changed

+172
-2
lines changed

prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/json/generator/GenericJsonSchemaGenerator.kt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
package ai.koog.prompt.structure.json.generator
22

3+
import kotlinx.serialization.builtins.nullable
34
import kotlinx.serialization.descriptors.PolymorphicKind
45
import kotlinx.serialization.descriptors.PrimitiveKind
56
import kotlinx.serialization.descriptors.SerialKind
67
import kotlinx.serialization.descriptors.StructureKind
78
import kotlinx.serialization.descriptors.elementNames
89
import kotlinx.serialization.json.JsonArray
10+
import kotlinx.serialization.json.JsonElement
911
import kotlinx.serialization.json.JsonObject
1012
import kotlinx.serialization.json.JsonObjectBuilder
1113
import kotlinx.serialization.json.JsonPrimitive
1214
import kotlinx.serialization.json.add
1315
import kotlinx.serialization.json.buildJsonArray
1416
import kotlinx.serialization.json.buildJsonObject
1517
import kotlinx.serialization.json.put
18+
import kotlinx.serialization.serializer
1619

1720
/**
1821
* Generic extensions of [JsonSchemaGenerator] that provides some common base implementations of visit methods.
@@ -51,8 +54,13 @@ public abstract class GenericJsonSchemaGenerator : JsonSchemaGenerator() {
5154
StructureKind.CLASS, StructureKind.OBJECT ->
5255
processObject(context)
5356

54-
is PolymorphicKind ->
55-
processPolymorphic(context)
57+
is PolymorphicKind -> {
58+
if (context.descriptor.serialName in JSON_ELEMENT_SERIAL_NAMES) {
59+
processJsonElement(context)
60+
} else {
61+
processPolymorphic(context)
62+
}
63+
}
5664

5765
else ->
5866
throw IllegalArgumentException("Encountered unsupported type while generating JSON schema: ${context.descriptor.kind}")
@@ -211,4 +219,16 @@ public abstract class GenericJsonSchemaGenerator : JsonSchemaGenerator() {
211219
override fun processPolymorphic(context: GenerationContext): JsonObject {
212220
throw UnsupportedOperationException("Polymorphic types are not supported by ${this::class.simpleName} generator")
213221
}
222+
223+
/**
 * Default handling for `JsonElement` descriptors in this generic generator.
 *
 * This implementation never produces a schema.
 *
 * @param context generation context for the descriptor being processed (unused here).
 * @throws UnsupportedOperationException always, with a message naming the concrete generator class.
 */
override fun processJsonElement(context: GenerationContext): JsonObject {
224+
throw UnsupportedOperationException("JsonElement is not supported by ${this::class.simpleName} generator")
225+
}
226+
}
227+
228+
/**
 * Serial names used to recognise `JsonElement` descriptors in the `PolymorphicKind` branch,
 * which checks `context.descriptor.serialName` against this set before falling back to
 * `processPolymorphic`.
 *
 * Holds the serial name of the `JsonElement` serializer's descriptor and of its `.nullable`
 * variant's descriptor. Built on first access via `by lazy`.
 */
private val JSON_ELEMENT_SERIAL_NAMES: Set<String> by lazy {
229+
val elementSerializer = serializer<JsonElement>()
230+
setOf(
231+
elementSerializer.descriptor.serialName,
232+
elementSerializer.nullable.descriptor.serialName
233+
)
214234
}

prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/json/generator/JsonSchemaGenerator.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ public abstract class JsonSchemaGenerator {
123123
protected abstract fun processObject(context: GenerationContext): JsonObject
124124
protected abstract fun processPolymorphic(context: GenerationContext): JsonObject
125125
protected abstract fun processClassDiscriminator(context: GenerationContext): JsonObject
126+
protected abstract fun processJsonElement(context: GenerationContext): JsonObject
126127
}
127128

128129
/**

prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/json/generator/StandardJsonSchemaGenerator.kt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,4 +264,24 @@ public open class StandardJsonSchemaGenerator : GenericJsonSchemaGenerator() {
264264
override fun processClassDiscriminator(context: GenerationContext): JsonObject = buildJsonObject {
265265
put(JsonSchemaConsts.Keys.CONST, context.descriptor.serialName)
266266
}
267+
268+
/**
 * Builds the schema for a `JsonElement` descriptor.
 *
 * - Non-nullable descriptor: an object with no constraints, carrying only the description
 *   added by `putDescription(context.currentDescription)`.
 * - Nullable descriptor: a `oneOf` of an empty schema and a `null`-typed schema, plus the
 *   description.
 *
 * @param context generation context; `descriptor.isNullable` selects the branch and
 *   `currentDescription` is passed to `putDescription`.
 * @return the generated schema object.
 */
override fun processJsonElement(context: GenerationContext): JsonObject {
269+
return if (context.descriptor.isNullable) {
270+
// NOTE(review): in JSON Schema, `oneOf` is satisfied only when exactly one subschema matches.
// A `null` value matches both the empty schema and `{"type": "null"}`, so a strict validator
// would reject `null` here. Confirm whether `anyOf` (or a bare empty schema) was intended.
buildJsonObject {
271+
put(
272+
JsonSchemaConsts.Keys.ONE_OF,
273+
buildJsonArray {
274+
add(buildJsonObject { /* empty schema for "any type" */ })
275+
add(buildJsonObject { put(JsonSchemaConsts.Keys.TYPE, JsonSchemaConsts.Types.NULL) })
276+
}
277+
)
278+
putDescription(context.currentDescription)
279+
}
280+
} else {
281+
buildJsonObject {
282+
// Empty object represents "any type" in JSON schema
283+
putDescription(context.currentDescription)
284+
}
285+
}
286+
}
267287
}

prompt/prompt-structure/src/commonTest/kotlin/ai/koog/prompt/structure/StructureFixingParserTest.kt

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@ import ai.koog.prompt.structure.json.JsonStructure
66
import kotlinx.coroutines.test.runTest
77
import kotlinx.serialization.Serializable
88
import kotlinx.serialization.json.Json
9+
import kotlinx.serialization.json.JsonElement
10+
import kotlinx.serialization.json.JsonObject
11+
import kotlinx.serialization.json.JsonPrimitive
912
import kotlin.test.Test
1013
import kotlin.test.assertEquals
1114
import kotlin.test.assertFailsWith
15+
import kotlin.test.assertTrue
1216

1317
class StructureFixingParserTest {
1418
@Serializable
@@ -17,6 +21,12 @@ class StructureFixingParserTest {
1721
val b: Int,
1822
)
1923

24+
@Serializable
25+
private data class DataWithWildcard(
26+
val id: String,
27+
val payload: JsonElement
28+
)
29+
2030
private val testData = TestData("test", 42)
2131
private val testDataJson = Json.encodeToString(testData)
2232
private val testStructure = JsonStructure.create<TestData>()
@@ -75,4 +85,46 @@ class StructureFixingParserTest {
7585
parser.parse(mockExecutor, testStructure, invalidContent)
7686
}
7787
}
88+
89+
@Test
90+
fun testFixInvalidJsonElementContent() = runTest {
91+
val parser = StructureFixingParser(
92+
model = OpenAIModels.Chat.GPT4oMini,
93+
retries = 2,
94+
)
95+
96+
val structure = JsonStructure.create<DataWithWildcard>()
97+
98+
val invalidContent = """
99+
{
100+
"id": "test-id",
101+
"payload": {
102+
unquotedKey: "someValue",
103+
brokenArray: [1, 2
104+
}
105+
}
106+
""".trimIndent()
107+
108+
val fixedContent = """
109+
{
110+
"id": "test-id",
111+
"payload": {
112+
"unquotedKey": "someValue",
113+
"brokenArray": [1, 2]
114+
}
115+
}
116+
""".trimIndent()
117+
118+
val mockExecutor = getMockExecutor {
119+
mockLLMAnswer(fixedContent) onRequestContains "unquotedKey"
120+
}
121+
122+
val result = parser.parse(mockExecutor, structure, invalidContent)
123+
124+
assertEquals("test-id", result.id)
125+
assertTrue(result.payload is JsonObject)
126+
127+
val payloadObj = result.payload
128+
assertEquals(JsonPrimitive("someValue"), payloadObj["unquotedKey"])
129+
}
78130
}

prompt/prompt-structure/src/commonTest/kotlin/ai/koog/prompt/structure/json/generator/JsonSchemaGeneratorTest.kt

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import kotlinx.serialization.SerialName
55
import kotlinx.serialization.Serializable
66
import kotlinx.serialization.json.ClassDiscriminatorMode
77
import kotlinx.serialization.json.Json
8+
import kotlinx.serialization.json.JsonElement
89
import kotlinx.serialization.modules.SerializersModule
910
import kotlinx.serialization.modules.polymorphic
1011
import kotlinx.serialization.serializer
@@ -162,6 +163,20 @@ class JsonSchemaGeneratorTest {
162163
val recursiveProperty: RecursiveTestClass?
163164
)
164165

166+
@Serializable
167+
@SerialName("EventData")
168+
data class EventData(
169+
@property:LLMDescription("Any valid JSON value.")
170+
val value: JsonElement
171+
)
172+
173+
@Serializable
174+
@SerialName("NullableEventData")
175+
data class NullableEventData(
176+
@property:LLMDescription("Any valid JSON value, or null.")
177+
val value: JsonElement? = null
178+
)
179+
165180
@Test
166181
fun testGenerateStandardSchema() {
167182
val result = standardGenerator.generate(json, "TestClass", serializer<TestClass>(), emptyMap())
@@ -805,4 +820,66 @@ class JsonSchemaGeneratorTest {
805820
basicGenerator.generate(json, "RecursiveTestClass", serializer<RecursiveTestClass>(), emptyMap())
806821
}
807822
}
823+
824+
@Test
825+
fun testStandardSchemaWithJsonElementProperty() {
826+
val result = standardGenerator.generate(json, "EventData", serializer<EventData>(), emptyMap())
827+
val schema = json.encodeToString(result.schema)
828+
829+
val expectedSchema = """
830+
{
831+
"${"$"}id": "EventData",
832+
"${"$"}defs": {
833+
"EventData": {
834+
"type": "object",
835+
"properties": {
836+
"value": {
837+
"description": "Any valid JSON value."
838+
}
839+
},
840+
"required": [
841+
"value"
842+
],
843+
"additionalProperties": false
844+
}
845+
},
846+
"${"$"}ref": "#/${"$"}defs/EventData"
847+
}
848+
""".trimIndent()
849+
850+
assertEquals(expectedSchema, schema)
851+
}
852+
853+
@Test
854+
fun testStandardSchemaWithNullableJsonElementProperty() {
855+
val result = standardGenerator.generate(json, "NullableEventData", serializer<NullableEventData>(), emptyMap())
856+
val schema = json.encodeToString(result.schema)
857+
858+
val expectedSchema = """
859+
{
860+
"${"$"}id": "NullableEventData",
861+
"${"$"}defs": {
862+
"NullableEventData": {
863+
"type": "object",
864+
"properties": {
865+
"value": {
866+
"oneOf": [
867+
{},
868+
{
869+
"type": "null"
870+
}
871+
],
872+
"description": "Any valid JSON value, or null."
873+
}
874+
},
875+
"required": [],
876+
"additionalProperties": false
877+
}
878+
},
879+
"${"$"}ref": "#/${"$"}defs/NullableEventData"
880+
}
881+
""".trimIndent()
882+
883+
assertEquals(expectedSchema, schema)
884+
}
808885
}

0 commit comments

Comments
 (0)