SPARKC-706: Add basic support for Cassandra vectors #1366

Merged · 6 commits · Jun 21, 2024
@@ -98,7 +98,7 @@ trait SparkCassandraITSpecBase
}

override def withFixture(test: NoArgTest): Outcome = wrapUnserializableExceptions {
super.withFixture(test)
super.withFixture(test)
}

def getKsName = {
@@ -147,16 +147,24 @@ trait SparkCassandraITSpecBase

/** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
* (if this is a DSE cluster) */
def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
Contributor: use Option instead of Some


def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
if (isDse(conn)) {
from(dse)(f)
dse match {
case Some(dseVersion) => from(dseVersion)(f)
case None => report(s"Skipped because not DSE")
}
} else {
from(cassandra)(f)
cassandra match {
case Some(cassandraVersion) => from(cassandraVersion)(f)
case None => report(s"Skipped because not Cassandra")
}
}
}
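For orientation, here is a minimal usage sketch of the reworked guard; it mirrors the SchemaSpec change later in this diff, and the test body is only illustrative:

```scala
// Runs the block only on Cassandra at or above 5.0 (CcmConfig.V5_0_0); with dse = None
// the block is reported as skipped on DSE clusters.
from(Some(CcmConfig.V5_0_0), None) {
  session.execute(s"ALTER TABLE $ks.test ADD d17_vector frozen<vector<int,3>>")
}
```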

/** Skips the given test if the Cluster Version is lower or equal to the given version */
def from(version: Version)(f: => Unit): Unit = {
private def from(version: Version)(f: => Unit): Unit = {
Contributor: This doc is not correct, right? It skips only when the version is lower.

skip(cluster.getCassandraVersion, version) { f }
}

@@ -1,6 +1,7 @@
package com.datastax.spark.connector.cql

import com.datastax.spark.connector.SparkCassandraITWordSpecBase
import com.datastax.spark.connector.ccm.CcmConfig
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.types._
import com.datastax.spark.connector.util.schemaFromCassandra
@@ -49,6 +50,9 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
s"""CREATE INDEX test_d9_m23423ap_idx ON $ks.test (full(d10_set))""")
session.execute(
s"""CREATE INDEX test_d7_int_idx ON $ks.test (d7_int)""")
from(Some(CcmConfig.V5_0_0), None) {
session.execute(s"ALTER TABLE $ks.test ADD d17_vector frozen<vector<int,3>>")
}

for (i <- 0 to 9) {
session.execute(s"insert into $ks.test (k1,k2,k3,c1,c2,c3,d10_set) " +
@@ -111,8 +115,8 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {

"allow to read regular column definitions" in {
val columns = table.regularColumns
columns.size shouldBe 16
columns.map(_.columnName).toSet shouldBe Set(
columns.size should be >= 16
columns.map(_.columnName).toSet should contain allElementsOf Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
@@ -136,6 +140,9 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
table.columnByName("d14_varchar").columnType shouldBe VarCharType
table.columnByName("d15_varint").columnType shouldBe VarIntType
table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
from(Some(CcmConfig.V5_0_0), None) {
table.columnByName("d17_vector").columnType shouldBe VectorType(IntType, 3)
}
}

"allow to list fields of a user defined type" in {
@@ -9,7 +9,7 @@
import com.datastax.oss.driver.api.core.cql.SimpleStatement
import com.datastax.oss.driver.api.core.cql.SimpleStatement._
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0}
import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean}
@@ -279,7 +279,7 @@
executor.execute(newInstance( s"""CREATE TABLE $ks.big_table (key INT PRIMARY KEY, value INT)"""))
val insert = session.prepare( s"""INSERT INTO $ks.big_table(key, value) VALUES (?, ?)""")
awaitAll {
for (k <- (0 until bigTableRowCount).grouped(100); i <- k) yield {

Check failure on line 282 in connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala (GitHub Actions / build (2.13.13, 4.1.4)): CassandraRDDSpec (It is not a test it is a sbt.testing.SuiteSelector): com.datastax.oss.driver.api.core.AllNodesFailedException: All 1 node(s) tried for the query failed (showing first 1 nodes, use getAllErrors() for more): Node(endPoint=localhost/127.0.0.1:9042, hostId=dce1bf1d-79cc-400a-9054-03200d57af07, hashCode=5cb6ea36): [com.datastax.oss.driver.api.core.NodeUnavailableException: No connection was available to Node(endPoint=localhost/127.0.0.1:9042, hostId=dce1bf1d-79cc-400a-9054-03200d57af07, hashCode=5cb6ea36)]
executor.executeAsync(insert.bind(i.asInstanceOf[AnyRef], i.asInstanceOf[AnyRef]))
}
}
@@ -794,7 +794,7 @@
results should contain ((KeyGroup(3, 300), (3, 300, "0003")))
}

it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) {
it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) {
val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect
result.length should be (1)
}
@@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement}
import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0
import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.embedded.SparkTemplate._
@@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {

}

it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){
it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){
val source = sc.parallelize(keys).map(x => (x, x * 100))
val someCass = source
.joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group"))
@@ -0,0 +1,238 @@
package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.{CqlSession, Version}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import com.datastax.spark.connector.mapper.ColumnMapper
import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.types.TypeConverter
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}

import scala.collection.convert.ImplicitConversionsToScala._
import scala.collection.immutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import scala.reflect._


abstract class VectorTypeTest[
ScalaType: ClassTag : TypeTag,
DriverType <: Number : ClassTag,
CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster
{
/** Skips the given test if the cluster is not Cassandra */
override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)
Contributor: please remove


override lazy val conn = CassandraConnector(sparkConf)

val VectorTable = "vectors"

def createVectorTable(session: CqlSession, table: String): Unit = {
session.execute(
s"""CREATE TABLE IF NOT EXISTS $ks.$table (
| id INT PRIMARY KEY,
Contributor: Can a vector be a primary key?
Contributor Author: Perhaps, why do you ask?
Contributor: Wondering if that's supported.

| v VECTOR<$typeName, 3>
|)""".stripMargin)
}

def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))

def minDSEVersion: Option[Version] = None

def vectorFromInts(ints: Int*): Seq[ScalaType]

def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType

override def beforeClass() {
conn.withSessionDo { session =>
session.execute(
s"""CREATE KEYSPACE IF NOT EXISTS $ks
|WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
.stripMargin)
}
}

private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
Contributor: Maybe assertVectors would be more clear, as the method doesn't return a boolean; instead it asserts on the content.

val returnedVectors = for (i <- expectedVectors.indices) yield {
rows.find(_.getInt("id") == i + 1).get.getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq
}

returnedVectors should contain theSameElementsInOrderAs expectedVectors
}

"SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_caseclass_to_df"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Append)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Append)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))

spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Overwrite)
.option("confirm.truncate", value = true)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))
}
}

it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_tuple_to_df"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.toDF("id", "v")
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Append)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
}
}

it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
}
}

it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
}
}

it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
}

it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_tuple_from_df"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
}

it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
}

it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
}

it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_rows_from_sql"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.conf.set("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
}

}

class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
override def vectorFromInts(ints: Int*): Seq[Int] = ints

override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
}

case class IntVectorItem(id: Int, v: Seq[Int])

class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)

override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
}

case class LongVectorItem(id: Int, v: Seq[Long])

class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)

override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
}

case class FloatVectorItem(id: Int, v: Seq[Float])

class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)

override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
}

case class DoubleVectorItem(id: Int, v: Seq[Double])

@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource

import java.util.Locale
import com.datastax.oss.driver.api.core.ProtocolVersion
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
import com.datastax.oss.driver.api.core.`type`.DataTypes._
import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
case udt: UserDefinedType => fromUdt(udt)
case t: TupleType => fromTuple(t)
case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
case VARINT =>
logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
primitiveCatalystDataType(cassandraType)
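As a rough sketch of what the new VectorType case yields (assuming the Java driver's DataTypes.vectorOf factory and that catalystDataType is reachable like this; this snippet is not part of the diff), a VECTOR column surfaces in Spark SQL as an array of its element type:

```scala
import com.datastax.oss.driver.api.core.`type`.DataTypes
import com.datastax.spark.connector.datasource.CassandraSourceUtil
import org.apache.spark.sql.types.{ArrayType, FloatType}

// Assumed factory: DataTypes.vectorOf builds the driver-side vector type.
val cqlType = DataTypes.vectorOf(DataTypes.FLOAT, 3)
// VECTOR<FLOAT, 3> is exposed to Spark as array<float>.
val sparkType = CassandraSourceUtil.catalystDataType(cqlType, nullable = true)
// sparkType == ArrayType(FloatType, containsNull = true)
```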
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
cassandraType match {
case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
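Equivalently for the connector-side type model, a short sketch of the mapping added to DataTypeConverter; VectorType(IntType, 3) is the same type the SchemaSpec assertion above expects, and calling catalystDataType directly like this is an assumption:

```scala
import com.datastax.spark.connector.types.{IntType, VectorType}
import org.apache.spark.sql.{types => catalystTypes}

// A connector VectorType maps to a Spark SQL array of its element type,
// exactly like the neighbouring ListType and SetType cases.
// (import of DataTypeConverter omitted; its package is not shown in this diff)
val sparkType = DataTypeConverter.catalystDataType(VectorType(IntType, 3), nullable = true)
// sparkType == catalystTypes.ArrayType(catalystTypes.IntegerType, true)
```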
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))

case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))

case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
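In effect this case lets a case class field declared as a Scala collection receive a vector column on read, mirroring the existing list and set cases; a hypothetical end-to-end sketch (IntVectorItem matches the case class defined in the new test file above, sc and ks are assumed from the test harness):

```scala
import com.datastax.spark.connector._

// Hypothetical table: vectors(id INT PRIMARY KEY, v VECTOR<INT, 3>).
case class IntVectorItem(id: Int, v: Seq[Int])

// The VectorType case above supplies the element converter that lands the column in Seq[Int].
val items = sc.cassandraTable[IntVectorItem](ks, "vectors").collect()
```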
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
val valueConverter = converter(valueColumnType, valueScalaType)
TypeConverter.javaHashMapConverter(keyConverter, valueConverter)

case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])

case (tt @ TupleType(argColumnType1, argColumnType2),
TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
val c1 = converter(argColumnType1.columnType, argScalaType1)
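On the write path, a minimal sketch of the converter the new case builds; cqlVectorConverter and its argument shape are taken from the case above, while constructing and invoking it standalone like this is an assumption:

```scala
import com.datastax.spark.connector.types.TypeConverter

// Element converter for the vector's element type, as in the match case above.
val elementConverter =
  TypeConverter.forType[java.lang.Integer].asInstanceOf[TypeConverter[Number]]
// A converter that turns a Scala collection into a 3-dimensional CQL vector value.
val toCqlVector = TypeConverter.cqlVectorConverter(3)(elementConverter)
val cqlValue = toCqlVector.convert(Seq(1, 2, 3))
```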