SPARKC-706: Add basic support for Cassandra vectors #1366
@@ -98,7 +98,7 @@ trait SparkCassandraITSpecBase
   }

   override def withFixture(test: NoArgTest): Outcome = wrapUnserializableExceptions {
-    super.withFixture(test)
+    super.withFixture(test)
   }

   def getKsName = {
@@ -147,16 +147,24 @@ trait SparkCassandraITSpecBase

   /** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
     * (if this is a DSE cluster) */
-  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
+  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
+
+  def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
     if (isDse(conn)) {
-      from(dse)(f)
+      dse match {
+        case Some(dseVersion) => from(dseVersion)(f)
+        case None => report(s"Skipped because not DSE")
+      }
     } else {
-      from(cassandra)(f)
+      cassandra match {
+        case Some(cassandraVersion) => from(cassandraVersion)(f)
+        case None => report(s"Skipped because not Cassandra")
+      }
     }
   }
   /** Skips the given test if the Cluster Version is lower or equal to the given version */
-  def from(version: Version)(f: => Unit): Unit = {
+  private def from(version: Version)(f: => Unit): Unit = {
     skip(cluster.getCassandraVersion, version) { f }
   }

Review comment: This doc is not correct, right? It skips only when the version is lower.
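A quick usage sketch of the new overload (the test name and version literal below are placeholders, not part of this diff): a suite that has a Cassandra minimum but no DSE equivalent can now pass None for the dse side and have the skip reported instead of silently mis-gated:

it should "round-trip vectors" in from(cassandra = Some(Version.parse("5.0-beta1")), dse = None) {
  // runs only on Cassandra clusters at or above the given version;
  // on a DSE cluster the helper reports "Skipped because not DSE"
}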
@@ -0,0 +1,238 @@
package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.{CqlSession, Version}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import com.datastax.spark.connector.mapper.ColumnMapper
import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.types.TypeConverter
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}

import scala.collection.convert.ImplicitConversionsToScala._
import scala.collection.immutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import scala.reflect._
abstract class VectorTypeTest[
    ScalaType: ClassTag : TypeTag,
    DriverType <: Number : ClassTag,
    CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper : RowReaderFactory : ValidRDDType](typeName: String)
  extends SparkCassandraITFlatSpecBase with DefaultCluster {

  /** Skips the given test if the cluster is not Cassandra */
  override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)

Review comment: please remove
  override lazy val conn = CassandraConnector(sparkConf)

  val VectorTable = "vectors"

  def createVectorTable(session: CqlSession, table: String): Unit = {
    session.execute(
      s"""CREATE TABLE IF NOT EXISTS $ks.$table (
         | id INT PRIMARY KEY,
         | v VECTOR<$typeName, 3>
         |)""".stripMargin)
  }

Review thread on the primary-key line:
- Can a vector be a primary key?
- Perhaps, why do you ask?
- Wondering if that's supported.
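For orientation, with typeName = "INT", table = "vectors", and a hypothetical keyspace named test (the real $ks value comes from the test harness), the interpolated statement expands to:

CREATE TABLE IF NOT EXISTS test.vectors (
 id INT PRIMARY KEY,
 v VECTOR<INT, 3>
)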
  def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))

  def minDSEVersion: Option[Version] = None

  def vectorFromInts(ints: Int*): Seq[ScalaType]

  def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType

  override def beforeClass() {
    conn.withSessionDo { session =>
      session.execute(
        s"""CREATE KEYSPACE IF NOT EXISTS $ks
           |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
          .stripMargin)
    }
  }
  private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
    val returnedVectors = for (i <- expectedVectors.indices) yield {
      rows.find(_.getInt("id") == i + 1).get
        .getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]])
        .iterator().toSeq
    }

    returnedVectors should contain theSameElementsInOrderAs expectedVectors
  }

Review comment: Maybe assertVectors would be more clear, as the method doesn't return a boolean; instead it asserts on the content.
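The extraction above goes through the Java driver's Row.getVector, which yields the driver's vector view as an iterable of the boxed element type. A minimal standalone sketch of the same call shape for a FLOAT column (the helper name floatsOf is hypothetical; only the column name "v" comes from this file):

import com.datastax.oss.driver.api.core.cql.Row
import scala.collection.convert.ImplicitConversionsToScala._

// Sketch: read column "v" as boxed floats and convert to a Scala Seq.
def floatsOf(row: Row): Seq[java.lang.Float] =
  row.getVector("v", classOf[java.lang.Float]).iterator().toSeq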
"SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_write_caseclass_to_df" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
|
||
spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))) | ||
.write | ||
.cassandraFormat(table, ks) | ||
.mode(SaveMode.Append) | ||
.save() | ||
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, | ||
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) | ||
|
||
spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9)))) | ||
.write | ||
.cassandraFormat(table, ks) | ||
.mode(SaveMode.Append) | ||
.save() | ||
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, | ||
Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9))) | ||
|
||
spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12)))) | ||
.write | ||
.cassandraFormat(table, ks) | ||
.mode(SaveMode.Overwrite) | ||
.option("confirm.truncate", value = true) | ||
.save() | ||
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, | ||
Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12))) | ||
} | ||
} | ||
|
||
it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_write_tuple_to_df" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
|
||
spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.toDF("id", "v") | ||
.write | ||
.cassandraFormat(table, ks) | ||
.mode(SaveMode.Append) | ||
.save() | ||
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, | ||
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) | ||
} | ||
} | ||
|
||
it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
|
||
spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, | ||
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) | ||
} | ||
} | ||
|
||
it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_write_tuple_to_rdd" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
|
||
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, | ||
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) | ||
} | ||
} | ||
|
||
it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_read_caseclass_from_df" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
} | ||
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
|
||
import spark.implicits._ | ||
spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs | ||
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))) | ||
} | ||
|
||
it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_read_tuple_from_df" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
} | ||
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
|
||
import spark.implicits._ | ||
spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs | ||
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))) | ||
} | ||
|
||
it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
} | ||
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
|
||
spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs | ||
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))) | ||
} | ||
|
||
it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_read_tuple_from_rdd" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
} | ||
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
|
||
spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs | ||
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))) | ||
} | ||
|
||
it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) { | ||
val table = s"${typeName.toLowerCase}_read_rows_from_sql" | ||
conn.withSessionDo { session => | ||
createVectorTable(session, table) | ||
} | ||
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) | ||
.saveToCassandra(ks, table) | ||
|
||
import spark.implicits._ | ||
spark.conf.set("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog") | ||
spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs | ||
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))) | ||
} | ||
|
||
} | ||
|
||
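The SQL test registers the connector's CassandraCatalog at runtime via spark.conf.set. As a usage note, the same catalog can be configured once when the session is built; a minimal sketch (the catalog name casscatalog follows the test above, the rest is standard SparkSession setup):

// Sketch: register the Cassandra catalog at session construction instead of per-test.
val spark = SparkSession.builder()
  .config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
  .getOrCreate()
// Tables are then addressable as casscatalog.<keyspace>.<table> from Spark SQL.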
class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
  override def vectorFromInts(ints: Int*): Seq[Int] = ints

  override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
}

case class IntVectorItem(id: Int, v: Seq[Int])

class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
  override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)

  override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
}

case class LongVectorItem(id: Int, v: Seq[Long])

class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
  override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)

  override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
}

case class FloatVectorItem(id: Int, v: Seq[Float])

class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
  override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)

  override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
}

case class DoubleVectorItem(id: Int, v: Seq[Double])
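For orientation, the float and double suites shift the integer seeds by 0.1 so the test vectors carry genuinely fractional values rather than integers in disguise. A tiny self-contained check of that mapping (the object name VectorSeedCheck is hypothetical; the definition mirrors FloatVectorTypeTest.vectorFromInts above):

object VectorSeedCheck extends App {
  // Same shape as FloatVectorTypeTest.vectorFromInts: integer seeds offset by 0.1f.
  def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)
  val v = vectorFromInts(1, 2, 3)
  assert(v.map(_.toInt) == Seq(1, 2, 3))      // the integer seeds are still recoverable
  assert(v.forall(x => x != x.toInt.toFloat)) // but every element is genuinely fractional
  println(v.mkString(", "))
}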
Review comment: use Option instead of Some