@@ -7,6 +7,7 @@ import java.util.concurrent.ConcurrentHashMap
77
88import scala .annotation .switch
99import scala .collection .immutable .ListMap
10+ import scala .collection .mutable
1011import scala .util .control .NonFatal
1112
1213import zio .json .JsonCodec ._
@@ -784,99 +785,105 @@ object JsonCodec {
784785 }
785786
786787 private def enumDecoder [Z ](parentSchema : Schema .Enum [Z ]): ZJsonDecoder [Z ] = {
787- val cases = parentSchema.cases
788- val caseNameAliases = cases.flatMap {
789- case Schema .Case (name, _, _, _, _, annotations) =>
790- annotations.flatMap {
791- case a : caseNameAliases => a.aliases.map(_ -> name)
792- case cn : caseName => List (cn.name -> name)
793- case _ => Nil
794- }
795- }.toMap
788+ val caseNameAliases = new mutable.HashMap [String , Schema .Case [Z , Any ]]
789+ parentSchema.cases.foreach { case_ =>
790+ val schema = case_.asInstanceOf [Schema .Case [Z , Any ]]
791+ caseNameAliases.put(case_.id, schema)
792+ case_.annotations.foreach {
793+ case a : caseNameAliases => a.aliases.foreach(a => caseNameAliases.put(a, schema))
794+ case cn : caseName => caseNameAliases.put(cn.name, schema)
795+ case _ =>
796+ }
797+ }
796798
797799 def error (msg : String , trace : List [JsonError ]): Nothing =
798800 throw UnsafeJson (JsonError .Message (msg) :: trace)
799801
800- if (cases.forall(_.schema.isInstanceOf [Schema .CaseClass0 [_]])) { // if all cases are CaseClass0, decode as String
801- val caseMap = cases.map { case_ =>
802- case_.id -> case_.schema.asInstanceOf [Schema .CaseClass0 [Z ]].defaultConstruct()
803- }.toMap
804- ZJsonDecoder .string.mapOrFail(
805- s =>
806- caseMap.get(caseNameAliases.getOrElse(s, s)) match {
807- case Some (z) => Right (z)
808- case _ => Left (" unrecognized string" )
809- }
810- )
811- } else {
812- if (parentSchema.annotations.exists(_.isInstanceOf [noDiscriminator])) {
813- new ZJsonDecoder [Z ] {
814- private [this ] val decoders = cases.map(c => schemaDecoder(c.schema))
815-
816- override def unsafeDecode (trace : List [JsonError ], in : RetractReader ): Z = {
817- var rr = RecordingReader (in)
818- val it = decoders.iterator
819- while (it.hasNext) {
820- try {
821- return it.next().unsafeDecode(trace, rr).asInstanceOf [Z ]
822- } catch {
823- case ex if NonFatal (ex) =>
824- rr.rewind()
825- rr = RecordingReader (rr)
826- }
802+ if (parentSchema.cases.forall(_.schema.isInstanceOf [Schema .CaseClass0 [_]])) { // if all cases are CaseClass0, decode as String
803+ new ZJsonDecoder [Z ] {
804+ private [this ] val cases = new util.HashMap [String , Z ](caseNameAliases.size * 2 )
805+
806+ caseNameAliases.foreach {
807+ case (name, case_) =>
808+ cases.put(name, case_.schema.asInstanceOf [Schema .CaseClass0 [Z ]].defaultConstruct())
809+ }
810+
811+ override def unsafeDecode (trace : List [JsonError ], in : RetractReader ): Z = {
812+ val result = cases.get(Lexer .string(trace, in).toString)
813+ if (result == null ) error(" unrecognized string" , trace)
814+ result
815+ }
816+ }
817+ } else if (parentSchema.annotations.exists(_.isInstanceOf [noDiscriminator])) {
818+ new ZJsonDecoder [Z ] {
819+ private [this ] val decoders = parentSchema.cases.map(c => schemaDecoder(c.schema))
820+
821+ override def unsafeDecode (trace : List [JsonError ], in : RetractReader ): Z = {
822+ var rr = RecordingReader (in)
823+ val it = decoders.iterator
824+ while (it.hasNext) {
825+ try {
826+ return it.next().unsafeDecode(trace, rr).asInstanceOf [Z ]
827+ } catch {
828+ case ex if NonFatal (ex) =>
829+ rr.rewind()
830+ rr = RecordingReader (rr)
827831 }
828- error(" none of the subtypes could decode the data" , trace)
829832 }
833+ error(" none of the subtypes could decode the data" , trace)
830834 }
831- } else {
832- parentSchema.annotations.collectFirst { case d : discriminatorName => d.tag } match {
833- case None =>
834- val decoderMap = cases.map { case_ =>
835- case_.id -> schemaDecoder(case_.schema).asInstanceOf [ZJsonDecoder [Any ]]
836- }.toMap
837- (trace : List [JsonError ], in : RetractReader ) => {
838- Lexer .char(trace, in, '{' )
839- if (! Lexer .firstField(trace, in)) error(" missing subtype" , trace)
840- val fieldName = Lexer .string(trace, in).toString
841- val subtype = caseNameAliases.getOrElse(fieldName, fieldName)
842- val trace_ = JsonError .ObjectAccess (subtype) :: trace
843- Lexer .char(trace_, in, ':' )
844- val decoded = decoderMap
845- .getOrElse(subtype, error(" unrecognized subtype" , trace_))
846- .unsafeDecode(trace_, in)
847- .asInstanceOf [Z ]
848- Lexer .nextField(trace_, in)
849- decoded
850- }
851- case Some (discriminatorName) =>
852- val caseMap = cases.map { case_ =>
853- case_.id -> case_.schema.asInstanceOf [Schema [Any ]]
854- }.toMap
855- (trace : List [JsonError ], in : RetractReader ) => {
856- Lexer .char(trace, in, '{' )
857- if (! Lexer .firstField(trace, in)) error(" missing subtype" , trace)
858- val rr = RecordingReader (in)
859- var index = 0
860- while ({
861- (Lexer .string(trace, rr).toString != discriminatorName) && {
862- Lexer .char(trace, rr, ':' )
863- Lexer .skipValue(trace, rr)
864- Lexer .nextField(trace, rr) || error(" missing subtype" , trace)
865- }
866- }) {
867- index += 1
835+ }
836+ } else {
837+ parentSchema.annotations.collectFirst { case d : discriminatorName => d.tag } match {
838+ case None =>
839+ val cases = new util.HashMap [String , (JsonError .ObjectAccess , ZJsonDecoder [Any ])](caseNameAliases.size * 2 )
840+ caseNameAliases.foreach {
841+ case (name, case_) =>
842+ cases.put(name, (JsonError .ObjectAccess (case_.id), schemaDecoder(case_.schema)))
843+ }
844+ (trace : List [JsonError ], in : RetractReader ) => {
845+ Lexer .char(trace, in, '{' )
846+ if (! Lexer .firstField(trace, in)) error(" missing subtype" , trace)
847+ val fieldName = Lexer .string(trace, in).toString
848+ val spanWithDecoder = cases.get(fieldName)
849+ if (spanWithDecoder eq null ) error(" unrecognized subtype" , trace)
850+ val trace_ = spanWithDecoder._1 :: trace
851+ Lexer .char(trace_, in, ':' )
852+ val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf [Z ]
853+ Lexer .nextField(trace_, in)
854+ decoded
855+ }
856+ case Some (discriminatorName) =>
857+ val cases = new util.HashMap [String , (JsonError .ObjectAccess , Schema [Any ])](caseNameAliases.size * 2 )
858+ caseNameAliases.foreach {
859+ case (name, case_) =>
860+ cases.put(name, (JsonError .ObjectAccess (case_.id), case_.schema))
861+ }
862+ val discriminatorSpan = JsonError .ObjectAccess (discriminatorName)
863+ (trace : List [JsonError ], in : RetractReader ) => {
864+ Lexer .char(trace, in, '{' )
865+ if (! Lexer .firstField(trace, in)) error(" missing subtype" , trace)
866+ val rr = RecordingReader (in)
867+ var index = 0
868+ while ({
869+ (Lexer .string(trace, rr).toString != discriminatorName) && {
870+ Lexer .char(trace, rr, ':' )
871+ Lexer .skipValue(trace, rr)
872+ Lexer .nextField(trace, rr) || error(" missing subtype" , trace)
868873 }
869- val trace_ = JsonError .ObjectAccess (discriminatorName) :: trace
870- Lexer .char(trace_, rr, ':' )
871- val fieldValue = Lexer .string(trace_, rr).toString
872- val subtype = caseNameAliases.getOrElse(fieldValue, fieldValue)
873- rr.rewind()
874- val schema = caseMap.getOrElse(subtype, error(" unrecognized subtype" , trace_))
875- schemaDecoder(schema, index)
876- .unsafeDecode(JsonError .ObjectAccess (subtype) :: trace_, rr)
877- .asInstanceOf [Z ]
874+ }) {
875+ index += 1
878876 }
879- }
877+ val trace_ = discriminatorSpan :: trace
878+ Lexer .char(trace_, rr, ':' )
879+ val fieldValue = Lexer .string(trace_, rr).toString
880+ rr.rewind()
881+ val spanWithSchema = cases.get(fieldValue)
882+ if (spanWithSchema eq null ) error(" unrecognized subtype" , trace_)
883+ schemaDecoder(spanWithSchema._2, index)
884+ .unsafeDecode(spanWithSchema._1 :: trace_, rr)
885+ .asInstanceOf [Z ]
886+ }
880887 }
881888 }
882889 }
0 commit comments