Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ lazy val test =
(project in file("zio-spark-test"))
.settings(crossScalaVersionSettings)
.settings(commonSettings)
.settings(macroExpansionSettings)
.settings(macroDefinitionSettings)
.settings(
name := "zio-spark-test",
scalaMajorVersion := CrossVersion.partialVersion(scalaVersion.value).get._1,
Expand Down Expand Up @@ -310,11 +312,39 @@ lazy val commonSettings =
* root (./) */
/* which lead to errors, eg. Path does not exist:
* file:./zio-spark/examples/simple-app/examples/simple-app/src/main/resources/data.csv */
lazy val noPublishingSettings =
Seq(
fork := false,
publish / skip := true,
// Don't generate documentation for the examples
Compile / doc / sources := Seq.empty,
Compile / packageDoc / publishArtifact := false
)
lazy val noPublishingSettings = Seq(
fork := false,
publish / skip := true,
// Don't generate documentation for the examples
Compile / doc / sources := Seq.empty,
Compile / packageDoc / publishArtifact := false
)

def macroExpansionSettings = Seq(
scalacOptions ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, 13)) => Seq("-Ymacro-annotations")
case _ => Seq.empty
}
},
libraryDependencies ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, x)) if x <= 12 =>
Seq(compilerPlugin(("org.scalamacros" % "paradise" % "2.1.1").cross(CrossVersion.full)))
case _ => Seq.empty
}
}
)

def macroDefinitionSettings = Seq(
scalacOptions += "-language:experimental.macros",
libraryDependencies ++= {
CrossVersion.partialVersion(scalaVersion.value).get._1 match {
case 3 => Seq.empty
case _ => Seq(
"org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pb, the user need the macro as a dependencie, as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided"
)
}
}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package zio.spark.test

import zio.spark.sql.SIO
import zio.test.TestResult

trait CompileVariants {
def assertSpark[A, B](value: => A)(assertion: SparkAssertion[A, B]): SIO[TestResult] = macro Macros.assert_impl

def assertZIOSpark[A, B](value: SIO[A])(assertion: SparkAssertion[A, B]): SIO[TestResult] =
macro Macros.assertZIO_impl
}
77 changes: 77 additions & 0 deletions zio-spark-test/src/main/scala-2/zio/spark/test/Macros.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package zio.spark.test

import scala.reflect.macros._

// Pilfered (with immense gratitude & minor modifications)
// from https://github.com/zio/zio/blob/series/2.x/test/shared/src/main/scala-2/zio/test/Macros.scala

// scalafix:off
private[test] object Macros {
def assert_impl(c: blackbox.Context)(value: c.Tree)(assertion: c.Tree): c.Tree = {
import c.universe._

// Pilfered (with immense gratitude & minor modifications)
// from https://github.com/com-lihaoyi/sourcecode
def text[T: c.WeakTypeTag](tree: c.Tree): (Int, Int, String) = {
val fileContent = new String(tree.pos.source.content)

var start =
tree.collect { case treeVal =>
treeVal.pos match {
case NoPosition => Int.MaxValue
case p => p.start
}
}.min
val initialStart = start

// Moves to the true beginning of the expression, in the case where the
// internal expression is wrapped in parens.
while ((start - 2) >= 0 && fileContent(start - 2) == '(')
start -= 1

val g = c.asInstanceOf[reflect.macros.runtime.Context].global
val parser = g.newUnitParser(fileContent.drop(start))
parser.expr()
val end = parser.in.lastOffset
(initialStart - start, start, fileContent.slice(start, start + end))
}

val codeString = text(value)._3
val assertionString = text(assertion)._3
q"_root_.zio.spark.test.assertSparkImpl($value, $codeString, $assertionString)($assertion)"
}

def assertZIO_impl(c: blackbox.Context)(value: c.Tree)(assertion: c.Tree): c.Tree = {
import c.universe._

// Pilfered (with immense gratitude & minor modifications)
// from https://github.com/com-lihaoyi/sourcecode
def text[T: c.WeakTypeTag](tree: c.Tree): (Int, Int, String) = {
val fileContent = new String(tree.pos.source.content)

var start =
tree.collect { case treeVal =>
treeVal.pos match {
case NoPosition => Int.MaxValue
case p => p.start
}
}.min
val initialStart = start

// Moves to the true beginning of the expression, in the case where the
// internal expression is wrapped in parens.
while ((start - 2) >= 0 && fileContent(start - 2) == '(')
start -= 1

val g = c.asInstanceOf[reflect.macros.runtime.Context].global
val parser = g.newUnitParser(fileContent.drop(start))
parser.expr()
val end = parser.in.lastOffset
(initialStart - start, start, fileContent.slice(start, start + end))
}

val codeString = text(value)._3
val assertionString = text(assertion)._3
q"_root_.zio.spark.test.assertZIOSparkImpl($value, $codeString, $assertionString)($assertion)"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package zio.spark.test

import zio.Trace
import zio.spark.sql.SIO
import zio.test.TestResult
import zio.internal.stacktracer.SourceLocation

trait CompileVariants {
inline def assertSpark[A, B](inline value: => A) (inline assertion: SparkAssertion[A, B])(implicit trace: Trace, sourceLocation: SourceLocation):SIO[TestResult] =
${Macros.assert_impl('value)('assertion, 'trace, 'sourceLocation)}

inline def assertZIOSpark[A, B](inline value: SIO[A])(inline assertion: SparkAssertion[A, B])(implicit trace: Trace, sourceLocation: SourceLocation): SIO[TestResult] =
${Macros.assertZIO_impl('value)('assertion, 'trace, 'sourceLocation)}
}

object CompileVariants {
def assertSparkProxy[A, B](
value: => A,
codePart: String,
assertionPart: String
)(
assertion: SparkAssertion[A, B]
)(implicit
trace: Trace,
sourceLocation: SourceLocation
) = zio.spark.test.assertSparkImpl(value, codePart, assertionPart)(assertion)

def assertZIOSparkProxy[A, B](
value: SIO[A],
codePart: String,
assertionPart: String
)(
assertion: SparkAssertion[A, B]
)(implicit
trace: Trace,
sourceLocation: SourceLocation
) = zio.spark.test.assertZIOSparkImpl(value, codePart, assertionPart)(assertion)
}
38 changes: 38 additions & 0 deletions zio-spark-test/src/main/scala-3/zio/spark/test/Macros.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package zio.spark.test

import scala.reflect.macros._

import zio._
import zio.test.TestResult
import zio.test.Macros.showExpr
import zio.internal.stacktracer.SourceLocation

import zio.spark.sql.SIO
import zio.spark.test.SparkAssertion

import scala.quoted._

// Pilfered (with immense gratitude & minor modifications)
// from https://github.com/zio/zio/blob/series/2.x/test/shared/src/main/scala-3/zio/test/Macros.scala
object Macros {
def assert_impl[A: Type, B: Type](value: Expr[A])(assertion: Expr[SparkAssertion[A, B]], trace: Expr[Trace], sourceLocation: Expr[SourceLocation])(using Quotes): Expr[SIO[TestResult]] = {
import quotes.reflect._
val codeString = showExpr(value)
val assertionString = showExpr(assertion)
'{_root_.zio.spark.test.CompileVariants.assertSparkProxy($value, ${Expr(codeString)}, ${Expr(assertionString)})($assertion)($trace, $sourceLocation)
}
}

def assertZIO_impl[A: Type, B: Type](value: Expr[SIO[A]])(assertion: Expr[SparkAssertion[A, B]], trace: Expr[Trace], sourceLocation: Expr[SourceLocation])(using Quotes): Expr[SIO[TestResult]] = {
import quotes.reflect._
val codeString = showExpr(value)
val assertionString = showExpr(assertion)
'{_root_.zio.spark.test.CompileVariants.assertZIOSparkProxy($value, ${Expr(codeString)}, ${Expr(assertionString)})($assertion)($trace, $sourceLocation)
}
}

def showExpr[A](expr: Expr[A])(using Quotes): String = {
import quotes.reflect._
expr.asTerm.pos.sourceCode.get
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ abstract class SharedZIOSparkSpecDefault extends ZIOSpec[SparkSession] {
def ss = defaultSparkSession

@SuppressWarnings(Array("scalafix:DisableSyntax.valInAbstract"))
override val bootstrap: TaskLayer[SparkSession] =
ss.asLayer
.tap(_.get.sparkContext.setLogLevel("ERROR"))
override val bootstrap: TaskLayer[SparkSession] = ss.asLayer.tap(_.get.sparkContext.setLogLevel("ERROR"))
}
41 changes: 41 additions & 0 deletions zio-spark-test/src/main/scala/zio/spark/test/SparkAssertion.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package zio.spark.test

import zio.internal.ansi.AnsiStringOps
import zio.internal.stacktracer.SourceLocation
import zio.spark.sql.{Dataset, SIO}
import zio.spark.sql.TryAnalysis.syntax.throwAnalysisException
import zio.test.{Assertion, TestArrow, TestResult}

/**
* Todo: instruction is just the content of f as a string such as
* {{{_.isEmpty}}} becomes {{{"isEmpty"}}}. We should use a macro
* instead.
*/
final case class SparkAssertion[A, B](f: A => SIO[B], assertion: Assertion[B], instruction: String)

object SparkAssertion {
private[test] def smartAssert[A](
expr: => A,
codePart: String,
assertionPart: String,
instructionPart: String
)(
assertion: Assertion[A]
)(implicit sourceLocation: SourceLocation): TestResult = {
lazy val value0 = expr
val completeString = codePart.blue + " did not satisfy " + assertionPart.cyan
val instructionString = s"$codePart.$instructionPart"

TestResult(
(TestArrow.succeed(value0).withCode(instructionString) >>> assertion.arrow).withLocation
.withCompleteCode(completeString)
)
}

def isEmpty[A]: SparkAssertion[Dataset[A], Boolean] = SparkAssertion(_.isEmpty, Assertion.isTrue, "isEmpty")
def shouldExist[A](expr: String): SparkAssertion[Dataset[A], Boolean] =
SparkAssertion(_.filter(expr).isEmpty, Assertion.isFalse, s"""filter("$expr").isEmpty""")

def shouldNotExist[A](expr: String): SparkAssertion[Dataset[A], Boolean] =
shouldExist[A](expr).copy(assertion = Assertion.isTrue)
}
33 changes: 31 additions & 2 deletions zio-spark-test/src/main/scala/zio/spark/test/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package zio.spark

import org.apache.spark.sql.Encoder

import zio.Trace
import zio._
import zio.internal.stacktracer.SourceLocation
import zio.spark.parameter._
import zio.spark.rdd.RDD
import zio.spark.sql._
import zio.spark.sql.implicits._

import scala.reflect.ClassTag

package object test {
package object test extends CompileVariants {
val defaultSparkSession: SparkSession.Builder =
SparkSession.builder
.master(localAllNodes)
Expand All @@ -20,4 +21,32 @@ package object test {
def Dataset[T: Encoder](values: T*)(implicit trace: Trace): SIO[Dataset[T]] = values.toDataset

def RDD[T: ClassTag](values: T*)(implicit trace: Trace): SIO[RDD[T]] = values.toRDD

private[test] def assertZIOSparkImpl[A, B](
value: SIO[A],
codePart: String,
assertionPart: String
)(
assertion: SparkAssertion[A, B]
)(implicit
trace: Trace,
sourceLocation: SourceLocation
) =
value.flatMap(assertion.f).map { a =>
SparkAssertion.smartAssert(a, codePart, assertionPart, assertion.instruction)(assertion.assertion)
}

private[test] def assertSparkImpl[A, B](
value: => A,
codePart: String,
assertionPart: String
)(
assertion: SparkAssertion[A, B]
)(implicit
trace: Trace,
sourceLocation: SourceLocation
) =
assertion.f(value).map { a =>
SparkAssertion.smartAssert(a, codePart, assertionPart, assertion.instruction)(assertion.assertion)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package zio.spark.test

import scala3encoders.given // scalafix:ok

import zio.spark.sql.implicits._
import zio.spark.test.SparkAssertion._
import zio.test.TestAspect.failing

object SparkAssertionSpec extends SharedZIOSparkSpecDefault {

override def spec =
suite("SparkAssertion spec")(
test("Assertions should work with assertSpark") {
for {
df <- Dataset[Int]()
result <- assertSpark(df)(isEmpty)
} yield result
},
test("It should assert that a dataset is empty") {
assertZIOSpark(Dataset[Int]())(isEmpty)
},
test("It should fail asserting that a dataset is empty") {
assertZIOSpark(Dataset(1, 2, 3))(isEmpty)
} @@ failing,
test("It should assert that at least a row respect the predicate") {
assertZIOSpark(Dataset(1, 2, 3))(shouldExist("value == 1"))
},
test("It should fail asserting that at least a row respect the predicate") {
assertZIOSpark(Dataset(1, 2, 3))(shouldExist("value == 4"))
} @@ failing,
test("It should assert that no rows respect the predicate") {
assertZIOSpark(Dataset(1, 2, 3))(shouldNotExist("value == 4"))
},
test("It should fail asserting that no rows respect the predicate") {
assertZIOSpark(Dataset(1, 2, 3))(shouldNotExist("value == 1"))
} @@ failing
)
}

This file was deleted.