Skip to content

Commit bb476c4

Browse files
committed
Lesser allocations and hash map lookups when parsing enums
1 parent 9f1919f commit bb476c4

File tree

1 file changed

+91
-84
lines changed

1 file changed

+91
-84
lines changed

zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala

Lines changed: 91 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import java.util.concurrent.ConcurrentHashMap
77

88
import scala.annotation.switch
99
import scala.collection.immutable.ListMap
10+
import scala.collection.mutable
1011
import scala.util.control.NonFatal
1112

1213
import zio.json.JsonCodec._
@@ -784,99 +785,105 @@ object JsonCodec {
784785
}
785786

786787
private def enumDecoder[Z](parentSchema: Schema.Enum[Z]): ZJsonDecoder[Z] = {
787-
val cases = parentSchema.cases
788-
val caseNameAliases = cases.flatMap {
789-
case Schema.Case(name, _, _, _, _, annotations) =>
790-
annotations.flatMap {
791-
case a: caseNameAliases => a.aliases.map(_ -> name)
792-
case cn: caseName => List(cn.name -> name)
793-
case _ => Nil
794-
}
795-
}.toMap
788+
val caseNameAliases = new mutable.HashMap[String, Schema.Case[Z, Any]]
789+
parentSchema.cases.foreach { case_ =>
790+
val schema = case_.asInstanceOf[Schema.Case[Z, Any]]
791+
caseNameAliases.put(case_.id, schema)
792+
case_.annotations.foreach {
793+
case a: caseNameAliases => a.aliases.foreach(a => caseNameAliases.put(a, schema))
794+
case cn: caseName => caseNameAliases.put(cn.name, schema)
795+
case _ =>
796+
}
797+
}
796798

797799
def error(msg: String, trace: List[JsonError]): Nothing =
798800
throw UnsafeJson(JsonError.Message(msg) :: trace)
799801

800-
if (cases.forall(_.schema.isInstanceOf[Schema.CaseClass0[_]])) { // if all cases are CaseClass0, decode as String
801-
val caseMap = cases.map { case_ =>
802-
case_.id -> case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct()
803-
}.toMap
804-
ZJsonDecoder.string.mapOrFail(
805-
s =>
806-
caseMap.get(caseNameAliases.getOrElse(s, s)) match {
807-
case Some(z) => Right(z)
808-
case _ => Left("unrecognized string")
809-
}
810-
)
811-
} else {
812-
if (parentSchema.annotations.exists(_.isInstanceOf[noDiscriminator])) {
813-
new ZJsonDecoder[Z] {
814-
private[this] val decoders = cases.map(c => schemaDecoder(c.schema))
815-
816-
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
817-
var rr = RecordingReader(in)
818-
val it = decoders.iterator
819-
while (it.hasNext) {
820-
try {
821-
return it.next().unsafeDecode(trace, rr).asInstanceOf[Z]
822-
} catch {
823-
case ex if NonFatal(ex) =>
824-
rr.rewind()
825-
rr = RecordingReader(rr)
826-
}
802+
if (parentSchema.cases.forall(_.schema.isInstanceOf[Schema.CaseClass0[_]])) { // if all cases are CaseClass0, decode as String
803+
new ZJsonDecoder[Z] {
804+
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size * 2)
805+
806+
caseNameAliases.foreach {
807+
case (name, case_) =>
808+
cases.put(name, case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct())
809+
}
810+
811+
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
812+
val result = cases.get(Lexer.string(trace, in).toString)
813+
if (result == null) error("unrecognized string", trace)
814+
result
815+
}
816+
}
817+
} else if (parentSchema.annotations.exists(_.isInstanceOf[noDiscriminator])) {
818+
new ZJsonDecoder[Z] {
819+
private[this] val decoders = parentSchema.cases.map(c => schemaDecoder(c.schema))
820+
821+
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
822+
var rr = RecordingReader(in)
823+
val it = decoders.iterator
824+
while (it.hasNext) {
825+
try {
826+
return it.next().unsafeDecode(trace, rr).asInstanceOf[Z]
827+
} catch {
828+
case ex if NonFatal(ex) =>
829+
rr.rewind()
830+
rr = RecordingReader(rr)
827831
}
828-
error("none of the subtypes could decode the data", trace)
829832
}
833+
error("none of the subtypes could decode the data", trace)
830834
}
831-
} else {
832-
parentSchema.annotations.collectFirst { case d: discriminatorName => d.tag } match {
833-
case None =>
834-
val decoderMap = cases.map { case_ =>
835-
case_.id -> schemaDecoder(case_.schema).asInstanceOf[ZJsonDecoder[Any]]
836-
}.toMap
837-
(trace: List[JsonError], in: RetractReader) => {
838-
Lexer.char(trace, in, '{')
839-
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
840-
val fieldName = Lexer.string(trace, in).toString
841-
val subtype = caseNameAliases.getOrElse(fieldName, fieldName)
842-
val trace_ = JsonError.ObjectAccess(subtype) :: trace
843-
Lexer.char(trace_, in, ':')
844-
val decoded = decoderMap
845-
.getOrElse(subtype, error("unrecognized subtype", trace_))
846-
.unsafeDecode(trace_, in)
847-
.asInstanceOf[Z]
848-
Lexer.nextField(trace_, in)
849-
decoded
850-
}
851-
case Some(discriminatorName) =>
852-
val caseMap = cases.map { case_ =>
853-
case_.id -> case_.schema.asInstanceOf[Schema[Any]]
854-
}.toMap
855-
(trace: List[JsonError], in: RetractReader) => {
856-
Lexer.char(trace, in, '{')
857-
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
858-
val rr = RecordingReader(in)
859-
var index = 0
860-
while ({
861-
(Lexer.string(trace, rr).toString != discriminatorName) && {
862-
Lexer.char(trace, rr, ':')
863-
Lexer.skipValue(trace, rr)
864-
Lexer.nextField(trace, rr) || error("missing subtype", trace)
865-
}
866-
}) {
867-
index += 1
835+
}
836+
} else {
837+
parentSchema.annotations.collectFirst { case d: discriminatorName => d.tag } match {
838+
case None =>
839+
val cases = new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
840+
caseNameAliases.foreach {
841+
case (name, case_) =>
842+
cases.put(name, (JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema)))
843+
}
844+
(trace: List[JsonError], in: RetractReader) => {
845+
Lexer.char(trace, in, '{')
846+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
847+
val fieldName = Lexer.string(trace, in).toString
848+
val spanWithDecoder = cases.get(fieldName)
849+
if (spanWithDecoder eq null) error("unrecognized subtype", trace)
850+
val trace_ = spanWithDecoder._1 :: trace
851+
Lexer.char(trace_, in, ':')
852+
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
853+
Lexer.nextField(trace_, in)
854+
decoded
855+
}
856+
case Some(discriminatorName) =>
857+
val cases = new util.HashMap[String, (JsonError.ObjectAccess, Schema[Any])](caseNameAliases.size * 2)
858+
caseNameAliases.foreach {
859+
case (name, case_) =>
860+
cases.put(name, (JsonError.ObjectAccess(case_.id), case_.schema))
861+
}
862+
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
863+
(trace: List[JsonError], in: RetractReader) => {
864+
Lexer.char(trace, in, '{')
865+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
866+
val rr = RecordingReader(in)
867+
var index = 0
868+
while ({
869+
(Lexer.string(trace, rr).toString != discriminatorName) && {
870+
Lexer.char(trace, rr, ':')
871+
Lexer.skipValue(trace, rr)
872+
Lexer.nextField(trace, rr) || error("missing subtype", trace)
868873
}
869-
val trace_ = JsonError.ObjectAccess(discriminatorName) :: trace
870-
Lexer.char(trace_, rr, ':')
871-
val fieldValue = Lexer.string(trace_, rr).toString
872-
val subtype = caseNameAliases.getOrElse(fieldValue, fieldValue)
873-
rr.rewind()
874-
val schema = caseMap.getOrElse(subtype, error("unrecognized subtype", trace_))
875-
schemaDecoder(schema, index)
876-
.unsafeDecode(JsonError.ObjectAccess(subtype) :: trace_, rr)
877-
.asInstanceOf[Z]
874+
}) {
875+
index += 1
878876
}
879-
}
877+
val trace_ = discriminatorSpan :: trace
878+
Lexer.char(trace_, rr, ':')
879+
val fieldValue = Lexer.string(trace_, rr).toString
880+
rr.rewind()
881+
val spanWithSchema = cases.get(fieldValue)
882+
if (spanWithSchema eq null) error("unrecognized subtype", trace_)
883+
schemaDecoder(spanWithSchema._2, index)
884+
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
885+
.asInstanceOf[Z]
886+
}
880887
}
881888
}
882889
}

0 commit comments

Comments
 (0)