Skip to content

Commit 13e7fb8

Browse files
committed
Use StringMatrix for enum decoding to reduce allocations and CPU usage
1 parent bb476c4 commit 13e7fb8

File tree

1 file changed

+122
-58
lines changed

1 file changed

+122
-58
lines changed

zio-schema-json/shared/src/main/scala/zio/schema/codec/JsonCodec.scala

Lines changed: 122 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,20 @@ import java.nio.CharBuffer
44
import java.nio.charset.StandardCharsets
55
import java.util
66
import java.util.concurrent.ConcurrentHashMap
7-
87
import scala.annotation.switch
98
import scala.collection.immutable.ListMap
109
import scala.collection.mutable
1110
import scala.util.control.NonFatal
12-
1311
import zio.json.JsonCodec._
1412
import zio.json.JsonDecoder.{ JsonError, UnsafeJson }
1513
import zio.json.ast.Json
1614
import zio.json.internal.{ Lexer, RecordingReader, RetractReader, StringMatrix, WithRecordingReader, Write }
1715
import zio.json.{
16+
JsonFieldDecoder,
17+
JsonFieldEncoder,
1818
JsonCodec => ZJsonCodec,
1919
JsonDecoder => ZJsonDecoder,
20-
JsonEncoder => ZJsonEncoder,
21-
JsonFieldDecoder,
22-
JsonFieldEncoder
20+
JsonEncoder => ZJsonEncoder
2321
}
2422
import zio.prelude.NonEmptyMap
2523
import zio.schema.Schema.GenericRecord
@@ -800,18 +798,33 @@ object JsonCodec {
800798
throw UnsafeJson(JsonError.Message(msg) :: trace)
801799

802800
if (parentSchema.cases.forall(_.schema.isInstanceOf[Schema.CaseClass0[_]])) { // if all cases are CaseClass0, decode as String
803-
new ZJsonDecoder[Z] {
804-
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size * 2)
805-
806-
caseNameAliases.foreach {
807-
case (name, case_) =>
808-
cases.put(name, case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct())
801+
if (caseNameAliases.size <= 64) {
802+
new ZJsonDecoder[Z] {
803+
private[this] val stringMatrix = new StringMatrix(caseNameAliases.keys.toArray)
804+
private[this] val cases = caseNameAliases.values
805+
.map(_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct())
806+
.toVector
807+
808+
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
809+
val idx = Lexer.enumeration(trace, in, stringMatrix)
810+
if (idx < 0) error("unrecognized string", trace)
811+
cases(idx)
812+
}
809813
}
814+
} else {
815+
new ZJsonDecoder[Z] {
816+
private[this] val cases = new util.HashMap[String, Z](caseNameAliases.size * 2)
810817

811-
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
812-
val result = cases.get(Lexer.string(trace, in).toString)
813-
if (result == null) error("unrecognized string", trace)
814-
result
818+
caseNameAliases.foreach {
819+
case (name, case_) =>
820+
cases.put(name, case_.schema.asInstanceOf[Schema.CaseClass0[Z]].defaultConstruct())
821+
}
822+
823+
override def unsafeDecode(trace: List[JsonError], in: RetractReader): Z = {
824+
val result = cases.get(Lexer.string(trace, in).toString)
825+
if (result == null) error("unrecognized string", trace)
826+
result
827+
}
815828
}
816829
}
817830
} else if (parentSchema.annotations.exists(_.isInstanceOf[noDiscriminator])) {
@@ -836,53 +849,104 @@ object JsonCodec {
836849
} else {
837850
parentSchema.annotations.collectFirst { case d: discriminatorName => d.tag } match {
838851
case None =>
839-
val cases = new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
840-
caseNameAliases.foreach {
841-
case (name, case_) =>
842-
cases.put(name, (JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema)))
843-
}
844-
(trace: List[JsonError], in: RetractReader) => {
845-
Lexer.char(trace, in, '{')
846-
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
847-
val fieldName = Lexer.string(trace, in).toString
848-
val spanWithDecoder = cases.get(fieldName)
849-
if (spanWithDecoder eq null) error("unrecognized subtype", trace)
850-
val trace_ = spanWithDecoder._1 :: trace
851-
Lexer.char(trace_, in, ':')
852-
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
853-
Lexer.nextField(trace_, in)
854-
decoded
852+
if (caseNameAliases.size <= 64) {
853+
val stringMatrix = new StringMatrix(caseNameAliases.keys.toArray)
854+
val cases = caseNameAliases.values.map { case_ =>
855+
(JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema))
856+
}.toVector
857+
(trace: List[JsonError], in: RetractReader) => {
858+
Lexer.char(trace, in, '{')
859+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
860+
val idx = Lexer.enumeration(trace, in, stringMatrix)
861+
if (idx < 0) error("unrecognized subtype", trace)
862+
val spanWithDecoder = cases(idx)
863+
val trace_ = spanWithDecoder._1 :: trace
864+
Lexer.char(trace_, in, ':')
865+
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
866+
Lexer.nextField(trace_, in)
867+
decoded
868+
}
869+
} else {
870+
val cases =
871+
new util.HashMap[String, (JsonError.ObjectAccess, ZJsonDecoder[Any])](caseNameAliases.size * 2)
872+
caseNameAliases.foreach {
873+
case (name, case_) =>
874+
cases.put(name, (JsonError.ObjectAccess(case_.id), schemaDecoder(case_.schema)))
875+
}
876+
(trace: List[JsonError], in: RetractReader) => {
877+
Lexer.char(trace, in, '{')
878+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
879+
val fieldName = Lexer.string(trace, in).toString
880+
val spanWithDecoder = cases.get(fieldName)
881+
if (spanWithDecoder eq null) error("unrecognized subtype", trace)
882+
val trace_ = spanWithDecoder._1 :: trace
883+
Lexer.char(trace_, in, ':')
884+
val decoded = spanWithDecoder._2.unsafeDecode(trace_, in).asInstanceOf[Z]
885+
Lexer.nextField(trace_, in)
886+
decoded
887+
}
855888
}
856889
case Some(discriminatorName) =>
857-
val cases = new util.HashMap[String, (JsonError.ObjectAccess, Schema[Any])](caseNameAliases.size * 2)
858-
caseNameAliases.foreach {
859-
case (name, case_) =>
860-
cases.put(name, (JsonError.ObjectAccess(case_.id), case_.schema))
861-
}
862-
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
863-
(trace: List[JsonError], in: RetractReader) => {
864-
Lexer.char(trace, in, '{')
865-
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
866-
val rr = RecordingReader(in)
867-
var index = 0
868-
while ({
869-
(Lexer.string(trace, rr).toString != discriminatorName) && {
870-
Lexer.char(trace, rr, ':')
871-
Lexer.skipValue(trace, rr)
872-
Lexer.nextField(trace, rr) || error("missing subtype", trace)
890+
if (caseNameAliases.size <= 64) {
891+
val discriminatorMatrix = new StringMatrix(Array(discriminatorName))
892+
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
893+
val caseMatrix = new StringMatrix(caseNameAliases.keys.toArray)
894+
val cases = caseNameAliases.values.map(case_ => (JsonError.ObjectAccess(case_.id), case_.schema)).toVector
895+
(trace: List[JsonError], in: RetractReader) => {
896+
Lexer.char(trace, in, '{')
897+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
898+
val rr = RecordingReader(in)
899+
var index = 0
900+
while ({
901+
(Lexer.enumeration(trace, rr, discriminatorMatrix) < 0) && {
902+
Lexer.char(trace, rr, ':')
903+
Lexer.skipValue(trace, rr)
904+
Lexer.nextField(trace, rr) || error("missing subtype", trace)
905+
}
906+
}) {
907+
index += 1
873908
}
874-
}) {
875-
index += 1
909+
val trace_ = discriminatorSpan :: trace
910+
Lexer.char(trace_, rr, ':')
911+
val idx = Lexer.enumeration(trace_, rr, caseMatrix)
912+
rr.rewind()
913+
if (idx < 0) error("unrecognized subtype", trace_)
914+
val spanWithSchema = cases(idx)
915+
schemaDecoder(spanWithSchema._2, index)
916+
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
917+
.asInstanceOf[Z]
918+
}
919+
} else {
920+
val cases = new util.HashMap[String, (JsonError.ObjectAccess, Schema[Any])](caseNameAliases.size * 2)
921+
caseNameAliases.foreach {
922+
case (name, case_) =>
923+
cases.put(name, (JsonError.ObjectAccess(case_.id), case_.schema))
924+
}
925+
val discriminatorSpan = JsonError.ObjectAccess(discriminatorName)
926+
(trace: List[JsonError], in: RetractReader) => {
927+
Lexer.char(trace, in, '{')
928+
if (!Lexer.firstField(trace, in)) error("missing subtype", trace)
929+
val rr = RecordingReader(in)
930+
var index = 0
931+
while ({
932+
(Lexer.string(trace, rr).toString != discriminatorName) && {
933+
Lexer.char(trace, rr, ':')
934+
Lexer.skipValue(trace, rr)
935+
Lexer.nextField(trace, rr) || error("missing subtype", trace)
936+
}
937+
}) {
938+
index += 1
939+
}
940+
val trace_ = discriminatorSpan :: trace
941+
Lexer.char(trace_, rr, ':')
942+
val fieldValue = Lexer.string(trace_, rr).toString
943+
rr.rewind()
944+
val spanWithSchema = cases.get(fieldValue)
945+
if (spanWithSchema eq null) error("unrecognized subtype", trace_)
946+
schemaDecoder(spanWithSchema._2, index)
947+
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
948+
.asInstanceOf[Z]
876949
}
877-
val trace_ = discriminatorSpan :: trace
878-
Lexer.char(trace_, rr, ':')
879-
val fieldValue = Lexer.string(trace_, rr).toString
880-
rr.rewind()
881-
val spanWithSchema = cases.get(fieldValue)
882-
if (spanWithSchema eq null) error("unrecognized subtype", trace_)
883-
schemaDecoder(spanWithSchema._2, index)
884-
.unsafeDecode(spanWithSchema._1 :: trace_, rr)
885-
.asInstanceOf[Z]
886950
}
887951
}
888952
}

0 commit comments

Comments
 (0)