Improve avro support (#691)
* Pure Avro ScalaCheck generation

Remove the need for data serialization when generating Avro specific records.
Enable logical-type conversions as described in the Avro specification (see the sketch after this list).

* Fix sampling

* Use latest Avro for tests

* Fix BigDiffy

* More powerful map comparison

* Migrate to scio-0.14

* Update copyright year

* Remove unused code

* Add doc

* Add extra schema check

* Prefer RHS schema as prior behaviour

* Remove unused imports

* Revert to previous behaviour
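
The logical-type work leans on Avro's conversion API. Below is a minimal sketch of what registering logical-type conversions on a GenericData model looks like — plain Avro API (1.9+ for the java.time-based TimeConversions), not ratatool internals, and the schema is illustrative:

import org.apache.avro.{Conversions, LogicalTypes, Schema}
import org.apache.avro.data.TimeConversions
import org.apache.avro.generic.GenericData

object LogicalTypesSketch {
  // Datum readers/writers built from this model surface timestamp-millis as
  // java.time.Instant and decimal as java.math.BigDecimal, instead of the
  // raw long/bytes representation.
  val model: GenericData = new GenericData()
  model.addLogicalTypeConversion(new TimeConversions.TimestampMillisConversion())
  model.addLogicalTypeConversion(new Conversions.DecimalConversion())

  // e.g. a long schema carrying the timestamp-millis logical type
  val tsSchema: Schema =
    LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG))
}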
Michel Davit authored Feb 16, 2024
1 parent 529ac52 commit 3eeebfd
Showing 14 changed files with 471 additions and 357 deletions.
5 changes: 3 additions & 2 deletions build.sbt
@@ -23,7 +23,7 @@ val algebirdVersion = "0.13.10"
// Keep in sync with Scio: https://github.com/spotify/scio/blob/v0.14.0/build.sbt
val scioVersion = "0.14.0"

val avroVersion = "1.8.2" // keep in sync with scio
val avroVersion = avroCompilerVersion // keep in sync with scio
val beamVersion = "2.53.0" // keep in sync with scio
val beamVendorVersion = "0.1" // keep in sync with scio
val bigqueryVersion = "v2-rev20230812-2.0.0" // keep in sync with scio
@@ -163,9 +163,9 @@ lazy val ratatoolCommon = project
name := "ratatool-common",
libraryDependencies ++= Seq(
"org.apache.avro" % "avro" % avroVersion,
"org.apache.avro" % "avro-mapred" % avroVersion classifier "hadoop2",
"com.google.guava" % "guava" % guavaVersion,
"com.google.apis" % "google-api-services-bigquery" % bigqueryVersion % Test,
"org.apache.avro" % "avro" % avroVersion % Test,
"org.apache.avro" % "avro" % avroVersion % Test classifier "tests",
"org.slf4j" % "slf4j-simple" % slf4jVersion % Test,
),
@@ -285,6 +285,7 @@ lazy val ratatoolScalacheck = project
name := "ratatool-scalacheck",
libraryDependencies ++= Seq(
"org.apache.avro" % "avro" % avroVersion,
"joda-time" % "joda-time" % jodaTimeVersion,
"org.scalacheck" %% "scalacheck" % scalaCheckVersion,
"org.apache.beam" % "beam-sdks-java-core" % beamVersion,
"org.apache.beam" % "beam-sdks-java-extensions-avro" % beamVersion,
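
For context, avroCompilerVersion is the setting the sbt-avro plugin exposes for the Avro compiler it runs; reusing it for the runtime artifact keeps the two in lockstep. A sketch of the idea (the plugin coordinates and version below are illustrative assumptions, not taken from this diff):

// project/plugins.sbt (sketch)
addSbtPlugin("com.github.sbt" % "sbt-avro" % "3.4.3")

// build.sbt (sketch) — pin the runtime avro artifact to the plugin's compiler version
libraryDependencies += "org.apache.avro" % "avro" % avroCompilerVersion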
4 changes: 1 addition & 3 deletions project/plugins.sbt
@@ -9,6 +9,4 @@ addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.10.0")
addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.11.1")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2")

-libraryDependencies ++= Seq(
-  "org.apache.avro" % "avro-compiler" % "1.8.2"
-)
+libraryDependencies ++= Seq("org.apache.avro" % "avro-compiler" % "1.8.2")
239 changes: 144 additions & 95 deletions ratatool-diffy/src/main/scala/com/spotify/ratatool/diffy/AvroDiffy.scala
@@ -18,118 +18,167 @@
package com.spotify.ratatool.diffy

import com.spotify.scio.coders.Coder
-import org.apache.avro.{Schema, SchemaValidatorBuilder}
-import org.apache.avro.generic.GenericRecord
-import scala.util.Try
+import org.apache.avro.Schema
+import org.apache.avro.generic.{GenericRecord, IndexedRecord}
+import org.apache.avro.specific.SpecificData

import scala.jdk.CollectionConverters._

-/** Field level diff tool for Avro records. */
-class AvroDiffy[T <: GenericRecord: Coder](
+/**
+ * Field level diff tool for Avro records.
+ *
+ * @param ignore
+ *   the set of fields to ignore during comparison.
+ * @param unordered
+ *   a list of fields to be treated as unordered, i.e. sorted before comparison.
+ * @param unorderedFieldKeys
+ *   a map of record field names to field names that can be keyed by when comparing nested
+ *   repeated records (currently not supported in the CLI).
+ */
+class AvroDiffy[T <: IndexedRecord: Coder](
ignore: Set[String] = Set.empty,
unordered: Set[String] = Set.empty,
unorderedFieldKeys: Map[String, String] = Map()
) extends Diffy[T](ignore, unordered, unorderedFieldKeys) {

-  override def apply(x: T, y: T): Seq[Delta] = {
-    new SchemaValidatorBuilder().canReadStrategy
-      .validateLatest()
-      .validate(y.getSchema, List(x.getSchema).asJava)
-    diff(Option(x), Option(y), "")
+  override def apply(x: T, y: T): Seq[Delta] = (x, y) match {
+    case (null, null) => Seq.empty
+    case (_, null)    => Seq(Delta("", Some(x), None, UnknownDelta))
+    case (null, _)    => Seq(Delta("", None, Some(y), UnknownDelta))
+    case _ if x.getSchema != y.getSchema => Seq(Delta("", Some(x), Some(y), UnknownDelta))
+    case _ => diff(x, y, x.getSchema, "")
}

-  def isAvroRecordType(schema: Schema): Boolean =
-    Schema.Type.RECORD.equals(schema.getType) ||
-      (Schema.Type.UNION.equals(schema.getType) &&
-        schema.getTypes.asScala.map(_.getType).contains(Schema.Type.RECORD))

-  private def diff(x: Option[GenericRecord], y: Option[GenericRecord], root: String): Seq[Delta] = {
-    // If a y exists we assume it has the superset of all fields, since x must be backwards
-    // compatible with it based on the SchemaValidator check in apply()
-    val schemaFields = (x, y) match {
-      case (Some(xVal), None) => xVal.getSchema.getFields.asScala.toList
-      case (_, Some(yVal))    => yVal.getSchema.getFields.asScala.toList
-      case _                  => List()
-    }
+  private def isRecord(schema: Schema): Boolean = schema.getType match {
+    case Schema.Type.RECORD => true
+    case Schema.Type.UNION  => schema.getTypes.asScala.map(_.getType).contains(Schema.Type.RECORD)
+    case _                  => false
+  }

-    schemaFields
-      .flatMap { f =>
-        val name = f.name()
-        val fullName = if (root.isEmpty) name else root + "." + name
-        getRawType(f.schema()).getType match {
-          case Schema.Type.RECORD =>
-            val a = x.flatMap(r => Option(r.get(name).asInstanceOf[GenericRecord]))
-            val b = y.flatMap(r => Option(r.get(name).asInstanceOf[GenericRecord]))
-            (a, b) match {
-              case (None, None)       => Nil
-              case (Some(_), None)    => Seq(Delta(fullName, a, None, UnknownDelta))
-              case (None, Some(_))    => Seq(Delta(fullName, None, b, UnknownDelta))
-              case (Some(_), Some(_)) => diff(a, b, fullName)
-            }
-          case Schema.Type.ARRAY if unordered.contains(fullName) =>
-            if (
-              unorderedFieldKeys.contains(fullName)
-              && isAvroRecordType(f.schema().getElementType)
-            ) {
-              val l = x
-                .flatMap(outer =>
-                  Option(outer.get(name).asInstanceOf[java.util.List[GenericRecord]].asScala.toList)
-                )
-                .getOrElse(List())
-                .flatMap(inner =>
-                  Try(inner.get(unorderedFieldKeys(fullName))).toOption.map(k => (k, inner))
-                )
-                .toMap
-              val r = y
-                .flatMap(outer =>
-                  Option(outer.get(name).asInstanceOf[java.util.List[GenericRecord]].asScala.toList)
-                )
-                .getOrElse(List())
-                .flatMap(inner =>
-                  Try(inner.get(unorderedFieldKeys(fullName))).toOption.map(k => (k, inner))
-                )
-                .toMap
-              (l.keySet ++ r.keySet).flatMap(k => diff(l.get(k), r.get(k), fullName)).toList
-            } else {
-              val a = x
-                .flatMap(r => Option(r.get(name).asInstanceOf[java.util.List[GenericRecord]]))
-                .map(sortList)
-              val b = y
-                .flatMap(r => Option(r.get(name).asInstanceOf[java.util.List[GenericRecord]]))
-                .map(sortList)
-              if (a == b) {
-                Nil
-              } else {
-                Seq(Delta(fullName, a, b, delta(a.orNull, b.orNull)))
-              }
-            }
-          case _ =>
-            val a = x.flatMap(r => Option(r.get(name)))
-            val b = y.flatMap(r => Option(r.get(name)))
-            if (a == b) Nil else Seq(Delta(fullName, a, b, delta(a.orNull, b.orNull)))
-        }
-      }
-      .filter(d => !ignore.contains(d.field))
+  private def isNumericType(`type`: Schema.Type): Boolean = `type` match {
+    case Schema.Type.INT | Schema.Type.LONG | Schema.Type.FLOAT | Schema.Type.DOUBLE => true
+    case _ => false
+  }

-  private def getRawType(schema: Schema): Schema = {
-    schema.getType match {
+  private def numericValue(value: AnyRef): Double = value match {
+    case i: java.lang.Integer => i.toDouble
+    case l: java.lang.Long    => l.toDouble
+    case f: java.lang.Float   => f.toDouble
+    case d: java.lang.Double  => d
+    case _ => throw new IllegalArgumentException(s"Unsupported numeric type: ${value.getClass}")
+  }

+  private def diff(x: AnyRef, y: AnyRef, schema: Schema, field: String): Seq[Delta] = {
+    val deltas = schema.getType match {
      case Schema.Type.UNION =>
-        val types = schema.getTypes
-        if (types.size == 2) {
-          if (types.get(0).getType == Schema.Type.NULL) {
-            types.get(1)
-          } else if (types.get(1).getType == Schema.Type.NULL) {
-            // incorrect use of Avro "nullable" but happens
-            types.get(0)
-          } else {
-            schema
-          }
-        } else {
-          schema
-        }
-      case _ => schema
-    }
-  }
+        // union, must resolve to the same type
+        val data = SpecificData.get()
+        val xTypeIndex = data.resolveUnion(schema, x)
+        val yTypeIndex = data.resolveUnion(schema, y)
+        if (xTypeIndex != yTypeIndex) {
+          // use Option as x or y can be null
+          Seq(Delta(field, Option(x), Option(y), UnknownDelta))
+        } else {
+          // same type, recurse with the refined schema
+          val fieldSchema = schema.getTypes.get(xTypeIndex)
+          diff(x, y, fieldSchema, field)
+        }

+      case Schema.Type.RECORD =>
+        // record, compare all fields
+        val a = x.asInstanceOf[IndexedRecord]
+        val b = y.asInstanceOf[IndexedRecord]
+        for {
+          f <- schema.getFields.asScala.toSeq
+          pos = f.pos()
+          name = f.name()
+          fullName = if (field.isEmpty) name else field + "." + name
+          delta <- diff(a.get(pos), b.get(pos), f.schema(), fullName)
+        } yield delta

+      case Schema.Type.ARRAY
+          if unorderedFieldKeys.contains(field) && isRecord(schema.getElementType) =>
+        // keyed array, compare like Map[String, Record]
+        val keyField = unorderedFieldKeys(field)
+        val as =
+          x.asInstanceOf[java.util.List[GenericRecord]].asScala.map(r => r.get(keyField) -> r).toMap
+        val bs =
+          y.asInstanceOf[java.util.List[GenericRecord]].asScala.map(r => r.get(keyField) -> r).toMap
+
+        for {
+          k <- (as.keySet ++ bs.keySet).toSeq
+          elementField = field + s"[$k]"
+          delta <- (as.get(k), bs.get(k)) match {
+            case (Some(a), Some(b)) => diff(a, b, schema.getElementType, field)
+            case (a, b)             => Seq(Delta(field, a, b, UnknownDelta))
+          }
+        } yield delta.copy(field = delta.field.replaceFirst(field, elementField))

+      case Schema.Type.ARRAY =>
+        // array, (un)ordered comparison
+        val xs = x.asInstanceOf[java.util.List[AnyRef]]
+        val ys = y.asInstanceOf[java.util.List[AnyRef]]
+        val (as, bs) = if (unordered.contains(field)) {
+          // unordered field: sort both sides before comparing
+          (sortList(xs).asScala, sortList(ys).asScala)
+        } else {
+          // ordered, position-wise comparison
+          (xs.asScala, ys.asScala)
+        }

+        val delta = if (as.size != bs.size) {
+          Some(UnknownDelta)
+        } else if (isNumericType(schema.getElementType.getType) && as != bs) {
+          Some(VectorDelta(vectorDelta(as.map(numericValue).toSeq, bs.map(numericValue).toSeq)))
+        } else if (as != bs) {
+          as.zip(bs)
+            .find { case (a, b) =>
+              a != b && diff(a, b, schema.getElementType, field).nonEmpty
+            }
+            .map(_ => UnknownDelta)
+        } else {
+          None
+        }
+        delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq

+      case Schema.Type.MAP =>
+        // map, compare key set and values
+        val as = x.asInstanceOf[java.util.Map[CharSequence, AnyRef]].asScala.map { case (k, v) =>
+          k.toString -> v
+        }
+        val bs = y.asInstanceOf[java.util.Map[CharSequence, AnyRef]].asScala.map { case (k, v) =>
+          k.toString -> v
+        }
+
+        for {
+          k <- (as.keySet ++ bs.keySet).toSeq
+          elementField = field + s"[$k]"
+          delta <- (as.get(k), bs.get(k)) match {
+            case (Some(a), Some(b)) => diff(a, b, schema.getValueType, field)
+            case (a, b)             => Seq(Delta(field, a, b, UnknownDelta))
+          }
+        } yield delta.copy(field = delta.field.replaceFirst(field, elementField))

+      case Schema.Type.STRING =>
+        // string, convert to java String for equality check
+        val a = x.asInstanceOf[CharSequence].toString
+        val b = y.asInstanceOf[CharSequence].toString
+        val delta = if (a == b) None else Some(StringDelta(stringDelta(a, b)))
+        delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq

+      case t if isNumericType(t) =>
+        // numeric, convert to Double for equality check
+        val a = numericValue(x)
+        val b = numericValue(y)
+        val delta = if (a == b) None else Some(NumericDelta(numericDelta(a, b)))
+        delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq

+      case _ =>
+        // other cases rely on object equality
+        val delta = if (x == y) None else Some(UnknownDelta)
+        delta.map(d => Delta(field, Some(x), Some(y), d)).toSeq
+    }
+
+    deltas.filterNot(d => ignore.contains(d.field))
+  }
 }
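
A usage sketch of the reworked AvroDiffy — the schema and field names are illustrative, and avroGenericRecordCoder is the scio coder helper used by BigDiffy below (assumed importable from com.spotify.scio.avro in scio 0.14):

import com.spotify.ratatool.diffy.AvroDiffy
import com.spotify.scio.avro._ // assumption: brings avroGenericRecordCoder in scio 0.14
import com.spotify.scio.coders.Coder
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}

object AvroDiffySketch {
  // illustrative schema: an id plus a repeated, order-insensitive tags field
  val schema: Schema = new Schema.Parser().parse(
    """{"type":"record","name":"Item","fields":[
      |{"name":"id","type":"string"},
      |{"name":"tags","type":{"type":"array","items":"string"}}
      |]}""".stripMargin
  )

  def main(args: Array[String]): Unit = {
    implicit val coder: Coder[GenericRecord] = avroGenericRecordCoder(schema)

    val a = new GenericRecordBuilder(schema)
      .set("id", "1")
      .set("tags", java.util.Arrays.asList("x", "y"))
      .build()
    val b = new GenericRecordBuilder(schema)
      .set("id", "1")
      .set("tags", java.util.Arrays.asList("y", "x"))
      .build()

    // "tags" is declared unordered, so both sides are sorted before comparing
    val diffy = new AvroDiffy[GenericRecord](unordered = Set("tags"))
    println(diffy(a, b)) // expected: empty Seq — the records match
  }
}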
ratatool-diffy/src/main/scala/com/spotify/ratatool/diffy/BigDiffy.scala
@@ -34,13 +34,16 @@ import com.spotify.scio.io.ClosedTap
import com.spotify.scio.parquet.avro._
import com.spotify.scio.values.SCollection
import com.twitter.algebird._
import org.apache.avro.{Schema, SchemaCompatibility, SchemaValidatorBuilder}
import org.apache.avro.generic.GenericRecord
import org.apache.avro.specific.SpecificRecordBase
import org.apache.beam.sdk.io.TextIO
import org.apache.beam.sdk.options.PipelineOptions
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.BaseEncoding
import org.slf4j.{Logger, LoggerFactory}

import java.nio.ByteBuffer
+import java.util.Collections
import scala.annotation.tailrec
import scala.collection.mutable
import scala.jdk.CollectionConverters._
@@ -249,14 +252,12 @@ object BigDiffy extends Command with Serializable {
(key, (Nil, diffType))
}
}
-      .map { x =>
-        x._2._2 match {
-          case DiffType.SAME        => accSame.inc()
-          case DiffType.DIFFERENT   => accDiff.inc()
-          case DiffType.MISSING_LHS => accMissingLhs.inc()
-          case DiffType.MISSING_RHS => accMissingRhs.inc()
-        }
-        x
+      .tap {
+        case (_, (_, DiffType.SAME))        => accSame.inc()
+        case (_, (_, DiffType.DIFFERENT))   => accDiff.inc()
+        case (_, (_, DiffType.MISSING_LHS)) => accMissingLhs.inc()
+        case (_, (_, DiffType.MISSING_RHS)) => accMissingRhs.inc()
+        case _                              =>
}

@@ -608,6 +609,9 @@
sys.exit(1)
}

+  private def avroFileSchema(path: String, options: PipelineOptions): Schema =
+    new AvroSampler(path, conf = Some(options)).sample(1, head = true).head.getSchema

private[diffy] def avroKeyFn(keys: Seq[String]): GenericRecord => MultiKey = {
@tailrec
def get(xs: Array[String], i: Int, r: GenericRecord): String =
@@ -745,13 +749,22 @@
val result = inputMode match {
case "avro" =>
if (rowRestriction.isDefined) {
throw new IllegalArgumentException(s"rowRestriction cannot be passed for avro inputs")
throw new IllegalArgumentException("rowRestriction cannot be passed for avro inputs")
}

+        val lhsSchema = avroFileSchema(lhs, sc.options)
+        val rhsSchema = avroFileSchema(rhs, sc.options)
+
+        // validate the rhs schema can be used to read lhs
+        new SchemaValidatorBuilder().canReadStrategy
+          .validateLatest()
+          .validate(rhsSchema, Collections.singletonList(lhsSchema))
+
+        if (lhsSchema != rhsSchema) {
+          logger.warn("Schemas are different but compatible, using the rhs schema for diff")
+        }
+        val schema = rhsSchema

-        val schema = new AvroSampler(rhs, conf = Some(sc.options))
-          .sample(1, head = true)
-          .head
-          .getSchema
implicit val grCoder: Coder[GenericRecord] = avroGenericRecordCoder(schema)
val diffy = new AvroDiffy[GenericRecord](ignore, unordered, unorderedKeys)
val lhsSCollection = sc.avroFile(lhs, schema)
@@ -760,7 +773,7 @@
.diff[GenericRecord](lhsSCollection, rhsSCollection, diffy, avroKeyFn(keys), ignoreNan)
case "parquet" =>
if (rowRestriction.isDefined) {
throw new IllegalArgumentException(s"rowRestriction cannot be passed for Parquet inputs")
throw new IllegalArgumentException("rowRestriction cannot be passed for Parquet inputs")
}
val compatSchema = ParquetIO.getCompatibleSchemaForFiles(lhs, rhs)
val diffy = new AvroDiffy[GenericRecord](ignore, unordered, unorderedKeys)(
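
The can-read validation added above, isolated as a standalone sketch (the helper name is illustrative; validate throws org.apache.avro.SchemaValidationException when the rhs schema cannot read data written with the lhs schema):

import java.util.Collections
import org.apache.avro.{Schema, SchemaValidatorBuilder}

object SchemaCheckSketch {
  // Returns the schema the diff will use: the rhs schema, provided it can
  // read data written with the lhs schema (Avro schema-resolution rules).
  def resolveDiffSchema(lhsSchema: Schema, rhsSchema: Schema): Schema = {
    new SchemaValidatorBuilder().canReadStrategy
      .validateLatest()
      .validate(rhsSchema, Collections.singletonList(lhsSchema))
    rhsSchema
  }
}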