Commit 8e55ff0

SPARKC-706: Add basic support for Cassandra vectors
1 parent 6c6ce1b commit 8e55ff0

12 files changed (+337, -15 lines)
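In user-facing terms, this commit lets a CQL vector<..., n> column surface in Spark as an array column. A rough sketch of the intended usage (the keyspace `test`, table `products`, and column `embedding` are hypothetical; the catalog name mirrors the one configured in the new integration tests):

  import org.apache.spark.sql.SparkSession

  // Expose Cassandra through a Spark catalog, as the new VectorTypeTest does.
  val spark = SparkSession.builder()
    .config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
    .getOrCreate()

  // Assuming products has a column `embedding vector<float, 3>`, after this change
  // it is read back as an ArrayType(FloatType) column.
  val df = spark.sql("SELECT id, embedding FROM casscatalog.test.products")
  df.printSchema()   // embedding: array<float>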

connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala (+25, -6)

@@ -97,8 +97,17 @@ trait SparkCassandraITSpecBase
     restoreSystemProps()
   }

+  private val currentTest: ThreadLocal[Option[String]] = new ThreadLocal[Option[String]] {
+    override def initialValue(): Option[String] = None
+  }
+
   override def withFixture(test: NoArgTest): Outcome = wrapUnserializableExceptions {
-    super.withFixture(test)
+    currentTest.set(Some(test.name))
+    try {
+      super.withFixture(test)
+    } finally {
+      currentTest.set(None)
+    }
   }

   def getKsName = {

@@ -131,7 +140,10 @@ trait SparkCassandraITSpecBase

   def pv = conn.withSessionDo(_.getContext.getProtocolVersion)

-  def report(message: String): Unit = alert(message)
+  def report(message: String): Unit = {
+    if (currentTest.get().isDefined)
+      cancel(message)
+  }

   val ks = getKsName

@@ -147,16 +159,23 @@ trait SparkCassandraITSpecBase

   /** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
     * (if this is a DSE cluster) */
-  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
+  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
+  def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
     if (isDse(conn)) {
-      from(dse)(f)
+      dse match {
+        case Some(dseVersion) => from(dseVersion)(f)
+        case None => report(s"Skipped because not DSE")
+      }
     } else {
-      from(cassandra)(f)
+      cassandra match {
+        case Some(cassandraVersion) => from(cassandraVersion)(f)
+        case None => report(s"Skipped because not Cassandra")
+      }
     }
   }

   /** Skips the given test if the Cluster Version is lower or equal to the given version */
-  def from(version: Version)(f: => Unit): Unit = {
+  private def from(version: Version)(f: => Unit): Unit = {
     skip(cluster.getCassandraVersion, version) { f }
   }
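For illustration, a spec can now gate a test on an optional minimum version per cluster kind; passing None makes `report` cancel the test on that kind of cluster. A sketch (the version and the test body are illustrative only):

  it should "exercise a Cassandra-only feature" in from(
    cassandra = Some(Version.parse("5.0-beta1")),   // minimum Cassandra version
    dse = None) {                                   // cancelled on DSE: "Skipped because not DSE"
    // the body runs only when the running cluster satisfies the requirement
  }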

connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala (+4, -2)

@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
        | d14_varchar varchar,
        | d15_varint varint,
        | d16_address frozen<address>,
+       | d17_vector frozen<vector<int,3>>,
        | PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
        |)
      """.stripMargin)

@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {

     "allow to read regular column definitions" in {
       val columns = table.regularColumns
-      columns.size shouldBe 16
+      columns.size shouldBe 17
       columns.map(_.columnName).toSet shouldBe Set(
         "d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
         "d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
         "d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
-        "d15_varint", "d16_address")
+        "d15_varint", "d16_address", "d17_vector")
     }

     "allow to read proper types of columns" in {

@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
       table.columnByName("d14_varchar").columnType shouldBe VarCharType
       table.columnByName("d15_varint").columnType shouldBe VarIntType
       table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
+      table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
     }

     "allow to list fields of a user defined type" in {

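The new assertion relies on the connector-side column type introduced for vectors. A minimal sketch of that type as used above (element type plus a fixed dimension):

  import com.datastax.spark.connector.types.{IntType, VectorType}

  // A CQL vector<int, 3> column is reported by the schema as VectorType[Int](IntType, 3)
  val d17VectorType = VectorType[Int](IntType, 3)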
connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala (+2, -2)

@@ -9,7 +9,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption
 import com.datastax.oss.driver.api.core.cql.SimpleStatement
 import com.datastax.oss.driver.api.core.cql.SimpleStatement._
 import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0}
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0}
 import com.datastax.spark.connector.cluster.DefaultCluster
 import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
 import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean}

@@ -794,7 +794,7 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster
     results should contain ((KeyGroup(3, 300), (3, 300, "0003")))
   }

-  it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) {
+  it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) {
     val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect
     result.length should be (1)
   }
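For context, `perPartitionLimit(1)` maps to CQL's PER PARTITION LIMIT, which is why the test now names both a Cassandra (3.6) and a DSE (5.1) minimum instead of calling the now-private single-version `from`. A sketch of the gated call:

  // Roughly equivalent CQL: SELECT ... FROM ks.clustering_time PER PARTITION LIMIT 1
  val firstRowPerPartition = sc.cassandraTable(ks, "clustering_time")
    .perPartitionLimit(1)
    .collect()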

connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala (+2, -2)

@@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
 import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement}
 import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion}
 import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0}
 import com.datastax.spark.connector.cluster.DefaultCluster
 import com.datastax.spark.connector.cql.CassandraConnector
 import com.datastax.spark.connector.embedded.SparkTemplate._

@@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {

   }

-  it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){
+  it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){
     val source = sc.parallelize(keys).map(x => (x, x * 100))
     val someCass = source
       .joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group"))
New test file (+242): abstract VectorTypeTest plus Int/Long/Float/Double concrete specs, in package com.datastax.spark.connector.rdd.typeTests

@@ -0,0 +1,242 @@ (entire file added)

package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.{CqlSession, Version}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import com.datastax.spark.connector.mapper.ColumnMapper
import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.types.TypeConverter
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}

import scala.collection.convert.ImplicitConversionsToScala._
import scala.collection.immutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import scala.reflect._


abstract class VectorTypeTest[
  ScalaType: ClassTag : TypeTag,
  DriverType <: Number : ClassTag,
  CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster
{
  /** Skips the given test if the cluster is not Cassandra */
  override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)

  override lazy val conn = CassandraConnector(sparkConf)

  val VectorTable = "vectors"

  def createVectorTable(session: CqlSession, table: String): Unit = {
    session.execute(
      s"""CREATE TABLE IF NOT EXISTS $ks.$table (
         | id INT PRIMARY KEY,
         | v VECTOR<$typeName, 3>
         |)""".stripMargin)
  }

  def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))

  def minDSEVersion: Option[Version] = None

  def vectorFromInts(ints: Int*): Seq[ScalaType]

  def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType

  override lazy val spark = SparkSession.builder()
    .config(sparkConf)
    .config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
    .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()

  override def beforeClass() {
    conn.withSessionDo { session =>
      session.execute(
        s"""CREATE KEYSPACE IF NOT EXISTS $ks
           |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
          .stripMargin)
    }
  }

  private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
    val returnedVectors = for (i <- expectedVectors.indices) yield {
      rows.find(_.getInt("id") == i + 1).get.getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq
    }

    returnedVectors should contain theSameElementsInOrderAs expectedVectors
  }

  "SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_caseclass_to_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Append)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

      spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Append)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))

      spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Overwrite)
        .option("confirm.truncate", value = true)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))
    }
  }

  it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_tuple_to_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
        .toDF("id", "v")
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Append)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
    }
  }

  it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
        .saveToCassandra(ks, table)
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
    }
  }

  it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
        .saveToCassandra(ks, table)
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
    }
  }

  it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    import spark.implicits._
    spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
      Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
  }

  it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_tuple_from_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    import spark.implicits._
    spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
  }

  it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
      Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
  }

  it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
  }

  it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_rows_from_sql"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    import spark.implicits._
    spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
  }

}

class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
  override def vectorFromInts(ints: Int*): Seq[Int] = ints

  override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
}

case class IntVectorItem(id: Int, v: Seq[Int])

class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
  override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)

  override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
}

case class LongVectorItem(id: Int, v: Seq[Long])

class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
  override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)

  override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
}

case class FloatVectorItem(id: Int, v: Seq[Float])

class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
  override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)

  override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
}

case class DoubleVectorItem(id: Int, v: Seq[Double])
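Distilled from the tests above, a round trip of a vector column through the RDD API would look roughly like this sketch (the keyspace, table, and case class names are illustrative; the table is assumed to have been created with `id INT PRIMARY KEY, v VECTOR<FLOAT, 3>`, as in `createVectorTable`):

  case class Embedding(id: Int, v: Seq[Float])

  // Write: each Seq[Float] is stored into the vector<float, 3> column.
  sc.parallelize(Seq(Embedding(1, Seq(0.1f, 0.2f, 0.3f)), Embedding(2, Seq(0.4f, 0.5f, 0.6f))))
    .saveToCassandra("test_ks", "float_vectors")

  // Read: the vector comes back as a Seq[Float] on the case class.
  val loaded = sc.cassandraTable[Embedding]("test_ks", "float_vectors").collect()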

connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala (+2, -1)

@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource

 import java.util.Locale
 import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
 import com.datastax.oss.driver.api.core.`type`.DataTypes._
 import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
 import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}

@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
       case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
       case udt: UserDefinedType => fromUdt(udt)
       case t: TupleType => fromTuple(t)
+      case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
       case VARINT =>
         logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
         primitiveCatalystDataType(cassandraType)
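A rough sketch of what the new branch does: a driver-level vector type is exposed to Catalyst as an array of its element type. This assumes `catalystDataType` is reachable as shown and that `DataTypes.vectorOf` is the driver's vector factory in the driver version that ships vector support:

  import com.datastax.oss.driver.api.core.`type`.DataTypes
  import com.datastax.spark.connector.datasource.CassandraSourceUtil

  // vector<int, 3> -> ArrayType(IntegerType, containsNull = true)
  val sparkType = CassandraSourceUtil.catalystDataType(DataTypes.vectorOf(DataTypes.INT, 3), nullable = true)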

connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala (+1)

@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
     cassandraType match {
       case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
       case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
+      case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
       case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
       case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
       case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
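The same idea on the connector-types side, as a sketch: the connector's `VectorType` converts to a Catalyst `ArrayType` of the element's Catalyst type, exactly as `SetType` and `ListType` do.

  import com.datastax.spark.connector.types.{FloatType, VectorType}
  import org.apache.spark.sql.cassandra.DataTypeConverter

  // vector<float, 3> -> ArrayType(FloatType, containsNull = true)
  val catalystType = DataTypeConverter.catalystDataType(VectorType[Float](FloatType, 3), nullable = true)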
