Skip to content

Commit 2932c6a

Browse files
committed
Use StringMatrix for enum decoding to reduce allocations and CPU usage
1 parent bb476c4 commit 2932c6a

File tree

1 file changed

+124
-58
lines changed

1 file changed

+124
-58
lines changed

zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala

Lines changed: 124 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,20 @@ import java.nio.CharBuffer
44
import java.nio.charset.StandardCharsets
55
import java.util
66
import java.util.concurrent.ConcurrentHashMap
7-
87
import scala.annotation.switch
98
import scala.collection.immutable.ListMap
109
import scala.collection.mutable
1110
import scala.util.control.NonFatal
12-
1311
import zio.json.JsonCodec._
1412
import zio.json.JsonDecoder.{ JsonError, UnsafeJson }
1513
import zio.json.ast.Json
1614
import zio.json.internal.{ Lexer, RecordingReader, RetractReader, StringMatrix, WithRecordingReader, Write }
1715
import zio.json.{
16+
JsonFieldDecoder,
17+
JsonFieldEncoder,
1818
JsonCodec => ZJsonCodec,
1919
JsonDecoder => ZJsonDecoder,
20-
JsonEncoder => ZJsonEncoder,
21-
JsonFieldDecoder,
22-
JsonFieldEncoder
20+
JsonEncoder => ZJsonEncoder
2321
}
2422
import zio.prelude.NonEmptyMap
2523
import zio.schema.Schema.GenericRecord
@@ -800,18 +798,33 @@ object JsonCodec {
800798
throw UnsafeJson(JsonError.Message(msg) :: trace)
801799

802800
if (parentSchema.cases.forall(_.schema.isInstanceOf[Schema.CaseClass0[_]])) { // if all cases are CaseClass0, decode as String
803-
new ZJsonDecoder[Z] {
804-
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size * 2)
805-
806-
caseNameAliases.foreach {
807-
case (name, case_) =>
808-
cases.put(name, case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct())
801+
if (caseNameAliases.size <= 64) {
802+
new ZJsonDecoder[Z] {
803+
private[this] val stringMatrix = new StringMatrix(caseNameAliases.keys.toArray)
804+
private[this] val cases = caseNameAliases.values.map { case_ =>
805+
case_.schema.asInstanceOf[Schema.CaseClass0[Any]].defaultConstruct()
806+
}.toArray.asInstanceOf[Array[Z]]
807+
808+
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
809+
val idx = Lexer.enumeration(trace, in, stringMatrix)
810+
if (idx < 0) error("unrecognized string", trace)
811+
cases(idx)
812+
}
809813
}
814+
} else {
815+
new ZJsonDecoder[Z] {
816+
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size * 2)
810817

811-
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
812-
val result = cases.get(Lexer.string(trace, in).toString)
813-
if (result == null) error("unrecognized string", trace)
814-
result
818+
caseNameAliases.foreach {
819+
case (name, case_) =>
820+
cases.put(name, case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct())
821+
}
822+
823+
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
824+
val result = cases.get(Lexer.string(trace, in).toString)
825+
if (result == null) error("unrecognized string", trace)
826+
result
827+
}
815828
}
816829
}
817830
} else if (parentSchema.annotations.exists(_.isInstanceOf[noDiscriminator])) {
@@ -836,53 +849,106 @@ object JsonCodec {
836849
} else {
837850
parentSchema.annotations.collectFirst { case d: discriminatorName => d.tag } match {
838851
case None =>
839-
val cases = new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
840-
caseNameAliases.foreach {
841-
case (name, case_) =>
842-
cases.put(name, (JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema)))
843-
}
844-
(trace: List[JsonError], in: RetractReader) => {
845-
Lexer.char(trace, in, '{')
846-
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
847-
val fieldName = Lexer.string(trace, in).toString
848-
val spanWithDecoder = cases.get(fieldName)
849-
if (spanWithDecoder eq null) error("unrecognized subtype", trace)
850-
val trace_ = spanWithDecoder._1 :: trace
851-
Lexer.char(trace_, in, ':')
852-
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
853-
Lexer.nextField(trace_, in)
854-
decoded
852+
if (caseNameAliases.size <= 64) {
853+
val stringMatrix = new StringMatrix(caseNameAliases.keys.toArray)
854+
val cases = caseNameAliases.values.map { case_ =>
855+
(JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema))
856+
}.toArray
857+
(trace: List[JsonError], in: RetractReader) => {
858+
Lexer.char(trace, in, '{')
859+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
860+
val idx = Lexer.enumeration(trace, in, stringMatrix)
861+
if (idx < 0) error("unrecognized subtype", trace)
862+
val spanWithDecoder = cases(idx)
863+
val trace_ = spanWithDecoder._1 :: trace
864+
Lexer.char(trace_, in, ':')
865+
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
866+
Lexer.nextField(trace_, in)
867+
decoded
868+
}
869+
} else {
870+
val cases =
871+
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
872+
caseNameAliases.foreach {
873+
case (name, case_) =>
874+
cases.put(name, (JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema)))
875+
}
876+
(trace: List[JsonError], in: RetractReader) => {
877+
Lexer.char(trace, in, '{')
878+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
879+
val fieldName = Lexer.string(trace, in).toString
880+
val spanWithDecoder = cases.get(fieldName)
881+
if (spanWithDecoder eq null) error("unrecognized subtype", trace)
882+
val trace_ = spanWithDecoder._1 :: trace
883+
Lexer.char(trace_, in, ':')
884+
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
885+
Lexer.nextField(trace_, in)
886+
decoded
887+
}
855888
}
856889
case Some(discriminatorName) =>
857-
val cases = new util.HashMap[String, (JsonError.ObjectAccess, Schema[Any])](caseNameAliases.size * 2)
858-
caseNameAliases.foreach {
859-
case (name, case_) =>
860-
cases.put(name, (JsonError.ObjectAccess(case_.id), case_.schema))
861-
}
862-
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
863-
(trace: List[JsonError], in: RetractReader) => {
864-
Lexer.char(trace, in, '{')
865-
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
866-
val rr = RecordingReader(in)
867-
var index = 0
868-
while ({
869-
(Lexer.string(trace, rr).toString != discriminatorName) && {
870-
Lexer.char(trace, rr, ':')
871-
Lexer.skipValue(trace, rr)
872-
Lexer.nextField(trace, rr) || error("missing subtype", trace)
890+
if (caseNameAliases.size <= 64) {
891+
val discriminatorMatrix = new StringMatrix(Array(discriminatorName))
892+
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
893+
val caseMatrix = new StringMatrix(caseNameAliases.keys.toArray)
894+
val cases = caseNameAliases.values
895+
.map(case_ => (JsonError.ObjectAccess(case_.id), case_.schema))
896+
.toArray
897+
(trace: List[JsonError], in: RetractReader) => {
898+
Lexer.char(trace, in, '{')
899+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
900+
val rr = RecordingReader(in)
901+
var index = 0
902+
while ({
903+
(Lexer.enumeration(trace, rr, discriminatorMatrix) < 0) && {
904+
Lexer.char(trace, rr, ':')
905+
Lexer.skipValue(trace, rr)
906+
Lexer.nextField(trace, rr) || error("missing subtype", trace)
907+
}
908+
}) {
909+
index += 1
873910
}
874-
}) {
875-
index += 1
911+
val trace_ = discriminatorSpan :: trace
912+
Lexer.char(trace_, rr, ':')
913+
val idx = Lexer.enumeration(trace_, rr, caseMatrix)
914+
rr.rewind()
915+
if (idx < 0) error("unrecognized subtype", trace_)
916+
val spanWithSchema = cases(idx)
917+
schemaDecoder(spanWithSchema._2, index)
918+
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
919+
.asInstanceOf[Z]
920+
}
921+
} else {
922+
val cases = new util.HashMap[String, (JsonError.ObjectAccess, Schema[Any])](caseNameAliases.size * 2)
923+
caseNameAliases.foreach {
924+
case (name, case_) =>
925+
cases.put(name, (JsonError.ObjectAccess(case_.id), case_.schema))
926+
}
927+
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
928+
(trace: List[JsonError], in: RetractReader) => {
929+
Lexer.char(trace, in, '{')
930+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
931+
val rr = RecordingReader(in)
932+
var index = 0
933+
while ({
934+
(Lexer.string(trace, rr).toString != discriminatorName) && {
935+
Lexer.char(trace, rr, ':')
936+
Lexer.skipValue(trace, rr)
937+
Lexer.nextField(trace, rr) || error("missing subtype", trace)
938+
}
939+
}) {
940+
index += 1
941+
}
942+
val trace_ = discriminatorSpan :: trace
943+
Lexer.char(trace_, rr, ':')
944+
val fieldValue = Lexer.string(trace_, rr).toString
945+
rr.rewind()
946+
val spanWithSchema = cases.get(fieldValue)
947+
if (spanWithSchema eq null) error("unrecognized subtype", trace_)
948+
schemaDecoder(spanWithSchema._2, index)
949+
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
950+
.asInstanceOf[Z]
876951
}
877-
val trace_ = discriminatorSpan :: trace
878-
Lexer.char(trace_, rr, ':')
879-
val fieldValue = Lexer.string(trace_, rr).toString
880-
rr.rewind()
881-
val spanWithSchema = cases.get(fieldValue)
882-
if (spanWithSchema eq null) error("unrecognized subtype", trace_)
883-
schemaDecoder(spanWithSchema._2, index)
884-
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
885-
.asInstanceOf[Z]
886952
}
887953
}
888954
}

0 commit comments

Comments
 (0)