1+ package io.github.hosseinkarami_dev.near.rpc.generator
2+
13import com.squareup.kotlinpoet.*
24import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
35import io.github.hosseinkarami_dev.near.rpc.generator.SealedInfo
@@ -11,15 +13,14 @@ object SerializerGenerator {
1113 fun generateFromSealedInfos (
1214 sealedInfos : List <SealedInfo >,
1315 serializerPackage : String ,
14- output : File ,
15- discriminatorFields : List <String > = listOf("type", "name")
16+ output : File
1617 ) {
1718 if (! output.exists()) output.mkdirs()
1819 if (sealedInfos.isEmpty()) return
1920
2021 for (info in sealedInfos) {
2122 try {
22- val fileSpec = generateSealedClassSerializer(info, serializerPackage, discriminatorFields )
23+ val fileSpec = generateSealedClassSerializer(info, serializerPackage)
2324 fileSpec.writeTo(output)
2425 } catch (ex: Exception ) {
2526 System .err.println (" Failed generating serializer for ${info.className} : ${ex.message} " )
@@ -32,13 +33,33 @@ object SerializerGenerator {
3233
3334 fun generateSealedClassSerializer (
3435 info : SealedInfo ,
35- serializerPackage : String ,
36- discriminatorFields : List <String > = listOf("type", "name")
36+ serializerPackage : String
3737 ): FileSpec {
3838 val modelsPkg = info.packageName
3939 val clsName = info.className
4040 val serializerName = " ${clsName} Serializer"
4141
42+ // derive discriminator candidate field names from the sealed-info itself
43+ val discCandidates: List <String > = run {
44+ val freq = mutableMapOf<String , Int >()
45+ val nVariants = info.variants.size
46+ for (v in info.variants) {
47+ for (p in v.props) {
48+ val t = sanitizeType(p.type)
49+ // consider only string-like properties as possible discriminators
50+ if (t.equals(" String" , ignoreCase = true ) || t.equals(" kotlin.String" , ignoreCase = true )) {
51+ freq[p.serialName] = (freq[p.serialName] ? : 0 ) + 1
52+ }
53+ }
54+ }
55+ if (freq.isEmpty()) emptyList()
56+ else {
57+ // threshold: appear in at least half of variants
58+ val threshold = (nVariants + 1 ) / 2
59+ freq.filter { it.value >= threshold }.keys.toList()
60+ }
61+ }
62+
4263 val modelClass = ClassName (modelsPkg, clsName)
4364 val kSerializerOfModel =
4465 ClassName (" kotlinx.serialization" , " KSerializer" ).parameterizedBy(modelClass)
@@ -153,9 +174,26 @@ object SerializerGenerator {
153174 val objBuilder = TypeSpec .objectBuilder(serializerName)
154175 .addSuperinterface(kSerializerOfModel)
155176
177+ // --- build a descriptor with one element per variant (names = serialName)
178+ val descriptorInitializer = CodeBlock .builder()
179+ descriptorInitializer.add(
180+ " %M(%S) {\n " ,
181+ MemberName (" kotlinx.serialization.descriptors" , " buildClassSerialDescriptor" ),
182+ " $modelsPkg .$clsName "
183+ )
184+ for (v in info.variants) {
185+ val variantClass = ClassName (modelsPkg, clsName, v.name)
186+ // produce: element("VariantSerialName", serializer<modelsPkg.ClsName.Variant>().descriptor)
187+ descriptorInitializer.add(
188+ " element(%S, %L)\n " ,
189+ v.serialName,
190+ CodeBlock .of(" serializer<%T>().descriptor" , ClassName (" kotlinx.serialization.json" ," JsonElement" ))
191+ )
192+ }
193+ descriptorInitializer.add(" }" )
156194 val descriptorProp = PropertySpec .builder(" descriptor" , serialDescriptor)
157195 .addModifiers(KModifier .OVERRIDE )
158- .initializer(" %M(%S) " , MemberName ( " kotlinx.serialization.descriptors " , " buildClassSerialDescriptor " ), " $modelsPkg . $clsName " )
196+ .initializer(descriptorInitializer.build() )
159197 .build()
160198 objBuilder.addProperty(descriptorProp)
161199
@@ -231,23 +269,14 @@ object SerializerGenerator {
231269 scb.addStatement(" return" )
232270 scb.endControlFlow()
233271
272+ // non-JSON: encode full variant serializer for each variant (descriptor elements align with variant order)
234273 scb.addStatement(" val out = encoder.beginStructure(descriptor)" )
235274 scb.beginControlFlow(" when (value)" )
236275 var idx = 0
237276 for (v in info.variants) {
238- if (v.kind == VariantInfo .Kind .OBJECT ) {
239- scb.addStatement(" %T.%L -> out.encodeStringElement(descriptor, %L, %S)" , modelClass, v.name, idx, v.serialName)
240- } else {
241- val variantClass = ClassName (modelsPkg, clsName, v.name)
242- val p = v.props.firstOrNull()
243- if (p != null ) {
244- val ser = serializerExpressionFor(p.type, v.name)
245- scb.addStatement(" is %T -> out.encodeSerializableElement(descriptor, %L, %L, value.%L)" , variantClass, idx, ser, p.name)
246- } else {
247- val varSer = CodeBlock .of(" serializer<%T>()" , variantClass)
248- scb.addStatement(" is %T -> out.encodeSerializableElement(descriptor, %L, %L, value)" , variantClass, idx, varSer)
249- }
250- }
277+ val variantClass = ClassName (modelsPkg, clsName, v.name)
278+ val varSer = CodeBlock .of(" serializer<%T>()" , variantClass)
279+ scb.addStatement(" is %T -> out.encodeSerializableElement(descriptor, %L, %L, value)" , variantClass, idx, varSer)
251280 idx++
252281 }
253282 scb.endControlFlow()
@@ -302,7 +331,7 @@ object SerializerGenerator {
302331 dcb.beginControlFlow(" is %T ->" , ClassName (" kotlinx.serialization.json" , " JsonObject" ))
303332 dcb.addStatement(" val jobj = element" )
304333
305- // ---------- new: field-based detection ----------
334+ // ---------- new: field-based detection with grouping to avoid duplicate checks ----------
306335 if (fieldBased) {
307336 dcb.addStatement(" // fieldBased union: detect variant by unique field presence" )
308337 for (v in dataVariants) {
@@ -333,6 +362,64 @@ object SerializerGenerator {
333362 }
334363 }
335364
365+ // --- Group variants by their required (non-nullable) keys to avoid emitting duplicated identical checks ---
366+ run {
367+ // build groups: Map(sortedRequiredKeysList -> List<VariantInfo>)
368+ val reqGroups = mutableMapOf<List <String >, MutableList <VariantInfo >>()
369+ for (v in dataVariants) {
370+ val reqKeys = v.props.filter { ! it.type.trim().endsWith(" ?" ) }.map { it.serialName }
371+ if (reqKeys.isNotEmpty()) {
372+ val sortedKey = reqKeys.sorted()
373+ reqGroups.computeIfAbsent(sortedKey) { mutableListOf () }.add(v)
374+ }
375+ }
376+
377+ for ((reqKeys, variantsWithSameReq) in reqGroups) {
378+ if (reqKeys.isEmpty()) continue
379+ val reqListLiteral = reqKeys.joinToString(" , " ) { " \" $it \" " }
380+ if (variantsWithSameReq.size == 1 ) {
381+ val v = variantsWithSameReq[0 ]
382+ dcb.beginControlFlow(" if (listOf($reqListLiteral ).all { jobj[it] != null })" )
383+ if (v.props.size == 1 && v.props[0 ].name == " value" ) {
384+ val ser = serializerExpressionFor(v.props[0 ].type, v.name)
385+ dcb.addStatement(" return %T(decoder.json.decodeFromJsonElement(%L, jobj[%S]!!))" , ClassName (modelsPkg, clsName, v.name), ser, v.props[0 ].serialName)
386+ } else {
387+ val variantSerializerCb = CodeBlock .of(" serializer<%T>()" , ClassName (modelsPkg, clsName, v.name))
388+ dcb.addStatement(" return decoder.json.decodeFromJsonElement(%L, jobj)" , variantSerializerCb)
389+ }
390+ dcb.endControlFlow()
391+ } else {
392+ // ambiguous group: try to disambiguate by 'type' field if present in all variants of the group
393+ val allHaveTypeField = variantsWithSameReq.all { vv -> vv.props.any { p -> p.serialName == " type" } }
394+ dcb.beginControlFlow(" if (listOf($reqListLiteral ).all { jobj[it] != null })" )
395+ if (allHaveTypeField) {
396+ dcb.addStatement(" val tfElem = jobj[%S]" , " type" )
397+ dcb.beginControlFlow(" if (tfElem is %T)" , ClassName (" kotlinx.serialization.json" , " JsonPrimitive" ))
398+ dcb.addStatement(" val tfVal = tfElem.content" )
399+ dcb.beginControlFlow(" when (tfVal)" )
400+ for (v in variantsWithSameReq) {
401+ dcb.beginControlFlow(" %S ->" , v.serialName)
402+ if (v.props.size == 1 && v.props[0 ].name == " value" ) {
403+ val ser = serializerExpressionFor(v.props[0 ].type, v.name)
404+ dcb.addStatement(" return %T(decoder.json.decodeFromJsonElement(%L, jobj[%S]!!))" , ClassName (modelsPkg, clsName, v.name), ser, v.props[0 ].serialName)
405+ } else {
406+ val variantSerializerCb = CodeBlock .of(" serializer<%T>()" , ClassName (modelsPkg, clsName, v.name))
407+ dcb.addStatement(" return decoder.json.decodeFromJsonElement(%L, jobj)" , variantSerializerCb)
408+ }
409+ dcb.endControlFlow()
410+ }
411+ dcb.addStatement(" else -> { /* not recognized by type field, fallthrough */ }" )
412+ dcb.endControlFlow() // end when(tfVal)
413+ dcb.endControlFlow() // end if (tfElem is JsonPrimitive)
414+ } else {
415+ // can't disambiguate here; allow later heuristics (wrapper/flat/heuristic) to handle these cases.
416+ dcb.addStatement(" // ambiguous required-keys group; skipping disambiguation here to avoid wrong decode" )
417+ }
418+ dcb.endControlFlow() // end if listOf(...).all
419+ }
420+ }
421+ }
422+
336423 // wrapper-style with single-key
337424 dcb.beginControlFlow(" if (jobj.size == 1)" )
338425 dcb.addStatement(" val entry = jobj.entries.first()" )
@@ -377,34 +464,37 @@ object SerializerGenerator {
377464 dcb.endControlFlow() // end when(key)
378465 dcb.endControlFlow() // end if (jobj.size == 1)
379466
380- // flat-style: try configured discriminators first, then heuristic fallback
467+ // flat-style: try configured discriminators first (derived from sealed info) , then heuristic fallback
381468 dcb.beginControlFlow(" else" )
382469
383- // inject discriminator candidates (from generator param)
384- val discListLiteral = discriminatorFields.joinToString(" , " ) { " \" $it \" " }
385- dcb.addStatement(" val discriminatorCandidates = listOf($discListLiteral )" )
386-
387470 dcb.addStatement(" var typeField: String? = null" )
388- // try configured candidates
389- dcb.beginControlFlow(" for (cand in discriminatorCandidates)" )
390- dcb.addStatement(" typeField = jobj[cand]?.jsonPrimitive?.contentOrNull" )
391- dcb.addStatement(" if (typeField != null) break" )
392- dcb.endControlFlow()
471+ if (discCandidates.isNotEmpty()) {
472+ val discListLiteral = discCandidates.joinToString(" , " ) { " \" $it \" " }
473+ dcb.addStatement(" val discriminatorCandidates = listOf($discListLiteral )" )
474+ dcb.beginControlFlow(" for (cand in discriminatorCandidates)" )
475+ dcb.addStatement(" val candElem = jobj[cand]" )
476+ dcb.beginControlFlow(" if (candElem is %T)" , ClassName (" kotlinx.serialization.json" , " JsonPrimitive" ))
477+ dcb.addStatement(" typeField = candElem.contentOrNull" )
478+ dcb.addStatement(" if (typeField != null) break" )
479+ dcb.endControlFlow()
480+ dcb.endControlFlow()
481+ }
393482
394483 // heuristic: if still null, look for any string value matching a known variant serialName
395484 val variantNamesList = info.variants.joinToString(" , " ) { " \" ${it.serialName} \" " }
396485 dcb.addStatement(" if (typeField == null) {" )
397486 dcb.addStatement(" val knownVariantNames = setOf($variantNamesList )" )
398487 dcb.addStatement(" for ((k, v) in jobj.entries) {" )
399- dcb.beginControlFlow(" if (v is %T && v.jsonPrimitive. isString)" , ClassName (" kotlinx.serialization.json" , " JsonElement " ))
400- dcb.addStatement(" val s = (v as %T).jsonPrimitive. content" , ClassName ( " kotlinx.serialization.json " , " JsonElement " ) )
488+ dcb.beginControlFlow(" if (v is %T && v.isString)" , ClassName (" kotlinx.serialization.json" , " JsonPrimitive " ))
489+ dcb.addStatement(" val s = v. content" )
401490 dcb.addStatement(" if (knownVariantNames.any { it.equals(s, ignoreCase = true) }) { typeField = s; break }" )
402491 dcb.endControlFlow()
403492 dcb.addStatement(" }" )
404493 dcb.addStatement(" }" )
405494
406495 // still null -> error
407- dcb.addStatement(" if (typeField == null) throw %T(%S)" , SerializationException ::class , " Missing discriminator (one of ${discriminatorFields.joinToString(" /" )} ) or recognizable variant in $clsName " )
496+ val discMsg = if (discCandidates.isNotEmpty()) " Missing discriminator (one of ${discCandidates.joinToString(" /" )} ) or recognizable variant in $clsName " else " Missing discriminator or recognizable variant in $clsName "
497+ dcb.addStatement(" if (typeField == null) throw %T(%S)" , SerializationException ::class , discMsg)
408498
409499 // normalize typeField for safe matching
410500 dcb.addStatement(" val tf = typeField.trim()" )
@@ -459,4 +549,4 @@ object SerializerGenerator {
459549 if (s.startsWith(" `" ) && s.endsWith(" `" ) && s.length > 1 ) s = s.substring(1 , s.length - 1 )
460550 return s.trim()
461551 }
462- }
552+ }
0 commit comments