Skip to content

Commit 3eeebfd

Browse files
author
Michel Davit
authored
Improve avro support (#691)
* Pure avro scalacheck generation Remove the need of data serialization to generate avro specific record. Enable logical-type conversion as described in the avro specification. * fix sampling * Use latest avro for test * Fix BigDiffy * More powerful map comparison * Migrate to scio-0.14 * Update copyright year * remove unused code * Add doc * Add extra schema check * Prefer RHS schema as prior behaviour * Remove unused imports * Revert to previous behaviour
1 parent 529ac52 commit 3eeebfd

File tree

14 files changed

+471
-357
lines changed

14 files changed

+471
-357
lines changed

build.sbt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ val algebirdVersion = "0.13.10"
2323
// Keep in sync with Scio: https://github.com/spotify/scio/blob/v0.14.0/build.sbt
2424
val scioVersion = "0.14.0"
2525

26-
val avroVersion = "1.8.2" // keep in sync with scio
26+
val avroVersion = avroCompilerVersion // keep in sync with scio
2727
val beamVersion = "2.53.0" // keep in sync with scio
2828
val beamVendorVersion = "0.1" // keep in sync with scio
2929
val bigqueryVersion = "v2-rev20230812-2.0.0" // keep in sync with scio
@@ -163,9 +163,9 @@ lazy val ratatoolCommon = project
163163
name := "ratatool-common",
164164
libraryDependencies ++= Seq(
165165
"org.apache.avro" % "avro" % avroVersion,
166-
"org.apache.avro" % "avro-mapred" % avroVersion classifier "hadoop2",
167166
"com.google.guava" % "guava" % guavaVersion,
168167
"com.google.apis" % "google-api-services-bigquery" % bigqueryVersion % Test,
168+
"org.apache.avro" % "avro" % avroVersion % Test,
169169
"org.apache.avro" % "avro" % avroVersion % Test classifier "tests",
170170
"org.slf4j" % "slf4j-simple" % slf4jVersion % Test,
171171
),
@@ -285,6 +285,7 @@ lazy val ratatoolScalacheck = project
285285
name := "ratatool-scalacheck",
286286
libraryDependencies ++= Seq(
287287
"org.apache.avro" % "avro" % avroVersion,
288+
"joda-time" % "joda-time" % jodaTimeVersion,
288289
"org.scalacheck" %% "scalacheck" % scalaCheckVersion,
289290
"org.apache.beam" % "beam-sdks-java-core" % beamVersion,
290291
"org.apache.beam" % "beam-sdks-java-extensions-avro" % beamVersion,

project/plugins.sbt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,4 @@ addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.10.0")
99
addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.1")
1010
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")
1111

12-
libraryDependencies ++= Seq(
13-
"org.apache.avro" % "avro-compiler" % "1.8.2"
14-
)
12+
libraryDependencies ++= Seq("org.apache.avro" % "avro-compiler" % "1.8.2")

ratatool-diffy/src/main/scala/com/spotify/ratatool/diffy/AvroDiffy.scala

Lines changed: 144 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -18,118 +18,167 @@
1818
package com.spotify.ratatool.diffy
1919

2020
import com.spotify.scio.coders.Coder
21-
import org.apache.avro.{Schema, SchemaValidatorBuilder}
22-
import org.apache.avro.generic.GenericRecord
23-
import scala.util.Try
21+
import org.apache.avro.Schema
22+
import org.apache.avro.generic.{GenericRecord, IndexedRecord}
23+
import org.apache.avro.specific.SpecificData
2424

2525
import scala.jdk.CollectionConverters._
2626

27-
/** Field level diff tool for Avro records. */
28-
class AvroDiffy[T <: GenericRecord: Coder](
27+
/**
28+
* Field level diff tool for Avro records.
29+
*
30+
* @param ignore
31+
* specify set of fields to ignore during comparison.
32+
* @param unordered
33+
* a list of fields to be treated as unordered, i.e. sort before comparison.
34+
* @param unorderedFieldKeys
35+
* a map of record field names to fields names that can be keyed by when comparing nested repeated
36+
* records. (currently not support in CLI)
37+
*/
38+
class AvroDiffy[T <: IndexedRecord: Coder](
2939
ignore: Set[String] = Set.empty,
3040
unordered: Set[String] = Set.empty,
3141
unorderedFieldKeys: Map[String, String] = Map()
3242
) extends Diffy[T](ignore, unordered, unorderedFieldKeys) {
3343

34-
override def apply(x: T, y: T): Seq[Delta] = {
35-
new SchemaValidatorBuilder().canReadStrategy
36-
.validateLatest()
37-
.validate(y.getSchema, List(x.getSchema).asJava)
38-
diff(Option(x), Option(y), "")
44+
override def apply(x: T, y: T): Seq[Delta] = (x, y) match {
45+
case (null, null) => Seq.empty
46+
case (_, null) => Seq(Delta("", Some(x), None, UnknownDelta))
47+
case (null, _) => Seq(Delta("", None, Some(y), UnknownDelta))
48+
case _ if x.getSchema != y.getSchema => Seq(Delta("", Some(x), Some(y), UnknownDelta))
49+
case _ => diff(x, y, x.getSchema, "")
3950
}
4051

41-
def isAvroRecordType(schema: Schema): Boolean =
42-
Schema.Type.RECORD.equals(schema.getType) ||
43-
(Schema.Type.UNION.equals(schema.getType) &&
44-
schema.getTypes.asScala.map(_.getType).contains(Schema.Type.RECORD))
45-
46-
private def diff(x: Option[GenericRecord], y: Option[GenericRecord], root: String): Seq[Delta] = {
47-
// If a y exists we assume it has the superset of all fields, since x must be backwards
48-
// compatible with it based on the SchemaValidator check in apply()
49-
val schemaFields = (x, y) match {
50-
case (Some(xVal), None) => xVal.getSchema.getFields.asScala.toList
51-
case (_, Some(yVal)) => yVal.getSchema.getFields.asScala.toList
52-
case _ => List()
53-
}
52+
private def isRecord(schema: Schema): Boolean = schema.getType match {
53+
case Schema.Type.RECORD => true
54+
case Schema.Type.UNION => schema.getTypes.asScala.map(_.getType).contains(Schema.Type.RECORD)
55+
case _ => false
56+
}
5457

55-
schemaFields
56-
.flatMap { f =>
57-
val name = f.name()
58-
val fullName = if (root.isEmpty) name else root + "." + name
59-
getRawType(f.schema()).getType match {
60-
case Schema.Type.RECORD =>
61-
val a = x.flatMap(r => Option(r.get(name).asInstanceOf[GenericRecord]))
62-
val b = y.flatMap(r => Option(r.get(name).asInstanceOf[GenericRecord]))
63-
(a, b) match {
64-
case (None, None) => Nil
65-
case (Some(_), None) => Seq(Delta(fullName, a, None, UnknownDelta))
66-
case (None, Some(_)) => Seq(Delta(fullName, None, b, UnknownDelta))
67-
case (Some(_), Some(_)) => diff(a, b, fullName)
68-
}
69-
case Schema.Type.ARRAY if unordered.contains(fullName) =>
70-
if (
71-
unorderedFieldKeys.contains(fullName)
72-
&& isAvroRecordType(f.schema().getElementType)
73-
) {
74-
val l = x
75-
.flatMap(outer =>
76-
Option(outer.get(name).asInstanceOf[java.util.List[GenericRecord]].asScala.toList)
77-
)
78-
.getOrElse(List())
79-
.flatMap(inner =>
80-
Try(inner.get(unorderedFieldKeys(fullName))).toOption.map(k => (k, inner))
81-
)
82-
.toMap
83-
val r = y
84-
.flatMap(outer =>
85-
Option(outer.get(name).asInstanceOf[java.util.List[GenericRecord]].asScala.toList)
86-
)
87-
.getOrElse(List())
88-
.flatMap(inner =>
89-
Try(inner.get(unorderedFieldKeys(fullName))).toOption.map(k => (k, inner))
90-
)
91-
.toMap
92-
(l.keySet ++ r.keySet).flatMap(k => diff(l.get(k), r.get(k), fullName)).toList
93-
} else {
94-
val a = x
95-
.flatMap(r => Option(r.get(name).asInstanceOf[java.util.List[GenericRecord]]))
96-
.map(sortList)
97-
val b = y
98-
.flatMap(r => Option(r.get(name).asInstanceOf[java.util.List[GenericRecord]]))
99-
.map(sortList)
100-
if (a == b) {
101-
Nil
102-
} else {
103-
Seq(Delta(fullName, a, b, delta(a.orNull, b.orNull)))
104-
}
105-
}
106-
case _ =>
107-
val a = x.flatMap(r => Option(r.get(name)))
108-
val b = y.flatMap(r => Option(r.get(name)))
109-
if (a == b) Nil else Seq(Delta(fullName, a, b, delta(a.orNull, b.orNull)))
110-
}
111-
}
112-
.filter(d => !ignore.contains(d.field))
58+
private def isNumericType(`type`: Schema.Type): Boolean = `type` match {
59+
case Schema.Type.INT | Schema.Type.LONG | Schema.Type.FLOAT | Schema.Type.DOUBLE => true
60+
case _ => false
11361
}
11462

115-
private def getRawType(schema: Schema): Schema = {
116-
schema.getType match {
63+
private def numericValue(value: AnyRef): Double = value match {
64+
case i: java.lang.Integer => i.toDouble
65+
case l: java.lang.Long => l.toDouble
66+
case f: java.lang.Float => f.toDouble
67+
case d: java.lang.Double => d
68+
case _ => throw new IllegalArgumentException(s"Unsupported numeric type: ${value.getClass}")
69+
}
70+
71+
private def diff(x: AnyRef, y: AnyRef, schema: Schema, field: String): Seq[Delta] = {
72+
val deltas = schema.getType match {
11773
case Schema.Type.UNION =>
118-
val types = schema.getTypes
119-
if (types.size == 2) {
120-
if (types.get(0).getType == Schema.Type.NULL) {
121-
types.get(1)
122-
} else if (types.get(1).getType == Schema.Type.NULL) {
123-
// incorrect use of Avro "nullable" but happens
124-
types.get(0)
125-
} else {
126-
schema
74+
// union, must resolve to same type
75+
val data = SpecificData.get()
76+
val xTypeIndex = data.resolveUnion(schema, x)
77+
val yTypeIndex = data.resolveUnion(schema, y)
78+
if (xTypeIndex != yTypeIndex) {
79+
// Use Option as x or y can be null
80+
Seq(Delta(field, Option(x), Option(y), UnknownDelta))
81+
} else {
82+
// same fields, refined schema
83+
val fieldSchema = schema.getTypes.get(xTypeIndex)
84+
diff(x, y, fieldSchema, field)
85+
}
86+
87+
case Schema.Type.RECORD =>
88+
// record, compare all fields
89+
val a = x.asInstanceOf[IndexedRecord]
90+
val b = y.asInstanceOf[IndexedRecord]
91+
for {
92+
f <- schema.getFields.asScala.toSeq
93+
pos = f.pos()
94+
name = f.name()
95+
fullName = if (field.isEmpty) name else field + "." + name
96+
delta <- diff(a.get(pos), b.get(pos), f.schema(), fullName)
97+
} yield delta
98+
99+
case Schema.Type.ARRAY
100+
if unorderedFieldKeys.contains(field) && isRecord(schema.getElementType) =>
101+
// keyed array, compare like Map[String, Record]
102+
val keyField = unorderedFieldKeys(field)
103+
val as =
104+
x.asInstanceOf[java.util.List[GenericRecord]].asScala.map(r => r.get(keyField) -> r).toMap
105+
val bs =
106+
y.asInstanceOf[java.util.List[GenericRecord]].asScala.map(r => r.get(keyField) -> r).toMap
107+
108+
for {
109+
k <- (as.keySet ++ bs.keySet).toSeq
110+
elementField = field + s"[$k]"
111+
delta <- (as.get(k), bs.get(k)) match {
112+
case (Some(a), Some(b)) => diff(a, b, schema.getElementType, field)
113+
case (a, b) => Seq(Delta(field, a, b, UnknownDelta))
127114
}
115+
} yield delta.copy(field = delta.field.replaceFirst(field, elementField))
116+
117+
case Schema.Type.ARRAY =>
118+
// array, (un)ordered comparison
119+
val xs = x.asInstanceOf[java.util.List[AnyRef]]
120+
val ys = y.asInstanceOf[java.util.List[AnyRef]]
121+
val (as, bs) = if (unordered.contains(field)) {
122+
// ordered comparison
123+
(sortList(xs).asScala, sortList(ys).asScala)
128124
} else {
129-
schema
125+
// unordered
126+
(xs.asScala, ys.asScala)
130127
}
131-
case _ => schema
128+
129+
val delta = if (as.size != bs.size) {
130+
Some(UnknownDelta)
131+
} else if (isNumericType(schema.getElementType.getType) && as != bs) {
132+
Some(VectorDelta(vectorDelta(as.map(numericValue).toSeq, bs.map(numericValue).toSeq)))
133+
} else if (as != bs) {
134+
as.zip(bs)
135+
.find { case (a, b) =>
136+
a != b && diff(a, b, schema.getElementType, field).nonEmpty
137+
}
138+
.map(_ => UnknownDelta)
139+
} else {
140+
None
141+
}
142+
delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq
143+
144+
case Schema.Type.MAP =>
145+
// map, compare key set and values
146+
val as = x.asInstanceOf[java.util.Map[CharSequence, AnyRef]].asScala.map { case (k, v) =>
147+
k.toString -> v
148+
}
149+
val bs = y.asInstanceOf[java.util.Map[CharSequence, AnyRef]].asScala.map { case (k, v) =>
150+
k.toString -> v
151+
}
152+
153+
for {
154+
k <- (as.keySet ++ bs.keySet).toSeq
155+
elementField = field + s"[$k]"
156+
delta <- (as.get(k), bs.get(k)) match {
157+
case (Some(a), Some(b)) => diff(a, b, schema.getValueType, field)
158+
case (a, b) => Seq(Delta(field, a, b, UnknownDelta))
159+
}
160+
} yield delta.copy(field = delta.field.replaceFirst(field, elementField))
161+
162+
case Schema.Type.STRING =>
163+
// string, convert to java String for equality check
164+
val a = x.asInstanceOf[CharSequence].toString
165+
val b = y.asInstanceOf[CharSequence].toString
166+
val delta = if (a == b) None else Some(StringDelta(stringDelta(a, b)))
167+
delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq
168+
169+
case t if isNumericType(t) =>
170+
// numeric, convert to Double for equality check
171+
val a = numericValue(x)
172+
val b = numericValue(y)
173+
val delta = if (a == b) None else Some(NumericDelta(numericDelta(a, b)))
174+
delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq
175+
176+
case _ =>
177+
// other case rely on object equality
178+
val delta = if (x == y) None else Some(UnknownDelta)
179+
delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq
132180
}
133-
}
134181

182+
deltas.filterNot(d => ignore.contains(d.field))
183+
}
135184
}

ratatool-diffy/src/main/scala/com/spotify/ratatool/diffy/BigDiffy.scala

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,16 @@ import com.spotify.scio.io.ClosedTap
3434
import com.spotify.scio.parquet.avro._
3535
import com.spotify.scio.values.SCollection
3636
import com.twitter.algebird._
37+
import org.apache.avro.{Schema, SchemaCompatibility, SchemaValidatorBuilder}
3738
import org.apache.avro.generic.GenericRecord
3839
import org.apache.avro.specific.SpecificRecordBase
3940
import org.apache.beam.sdk.io.TextIO
41+
import org.apache.beam.sdk.options.PipelineOptions
4042
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding
4143
import org.slf4j.{Logger, LoggerFactory}
4244

4345
import java.nio.ByteBuffer
46+
import java.util.Collections
4447
import scala.annotation.tailrec
4548
import scala.collection.mutable
4649
import scala.jdk.CollectionConverters._
@@ -249,14 +252,12 @@ object BigDiffy extends Command with Serializable {
249252
(key, (Nil, diffType))
250253
}
251254
}
252-
.map { x =>
253-
x._2._2 match {
254-
case DiffType.SAME => accSame.inc()
255-
case DiffType.DIFFERENT => accDiff.inc()
256-
case DiffType.MISSING_LHS => accMissingLhs.inc()
257-
case DiffType.MISSING_RHS => accMissingRhs.inc()
258-
}
259-
x
255+
.tap {
256+
case (_, (_, DiffType.SAME)) => accSame.inc()
257+
case (_, (_, DiffType.DIFFERENT)) => accDiff.inc()
258+
case (_, (_, DiffType.MISSING_LHS)) => accMissingLhs.inc()
259+
case (_, (_, DiffType.MISSING_RHS)) => accMissingRhs.inc()
260+
case _ =>
260261
}
261262
}
262263

@@ -608,6 +609,9 @@ object BigDiffy extends Command with Serializable {
608609
sys.exit(1)
609610
}
610611

612+
private def avroFileSchema(path: String, options: PipelineOptions): Schema =
613+
new AvroSampler(path, conf = Some(options)).sample(1, head = true).head.getSchema
614+
611615
private[diffy] def avroKeyFn(keys: Seq[String]): GenericRecord => MultiKey = {
612616
@tailrec
613617
def get(xs: Array[String], i: Int, r: GenericRecord): String =
@@ -745,13 +749,22 @@ object BigDiffy extends Command with Serializable {
745749
val result = inputMode match {
746750
case "avro" =>
747751
if (rowRestriction.isDefined) {
748-
throw new IllegalArgumentException(s"rowRestriction cannot be passed for avro inputs")
752+
throw new IllegalArgumentException("rowRestriction cannot be passed for avro inputs")
753+
}
754+
755+
val lhsSchema = avroFileSchema(lhs, sc.options)
756+
val rhsSchema = avroFileSchema(rhs, sc.options)
757+
758+
// validate the rhs schema can be used to read lhs
759+
new SchemaValidatorBuilder().canReadStrategy
760+
.validateLatest()
761+
.validate(rhsSchema, Collections.singletonList(lhsSchema))
762+
763+
if (lhsSchema != rhsSchema) {
764+
logger.warn("Schemas are different but compatible, using the rhs schema for diff")
749765
}
766+
val schema = rhsSchema
750767

751-
val schema = new AvroSampler(rhs, conf = Some(sc.options))
752-
.sample(1, head = true)
753-
.head
754-
.getSchema
755768
implicit val grCoder: Coder[GenericRecord] = avroGenericRecordCoder(schema)
756769
val diffy = new AvroDiffy[GenericRecord](ignore, unordered, unorderedKeys)
757770
val lhsSCollection = sc.avroFile(lhs, schema)
@@ -760,7 +773,7 @@ object BigDiffy extends Command with Serializable {
760773
.diff[GenericRecord](lhsSCollection, rhsSCollection, diffy, avroKeyFn(keys), ignoreNan)
761774
case "parquet" =>
762775
if (rowRestriction.isDefined) {
763-
throw new IllegalArgumentException(s"rowRestriction cannot be passed for Parquet inputs")
776+
throw new IllegalArgumentException("rowRestriction cannot be passed for Parquet inputs")
764777
}
765778
val compatSchema = ParquetIO.getCompatibleSchemaForFiles(lhs, rhs)
766779
val diffy = new AvroDiffy[GenericRecord](ignore, unordered, unorderedKeys)(

0 commit comments

Comments
 (0)