package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.{CqlSession, Version}
import com.datastax.spark.connector._
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import com.datastax.spark.connector.mapper.ColumnMapper
import com.datastax.spark.connector.rdd.ValidRDDType
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.collection.convert.ImplicitConversionsToScala._
import scala.reflect._
import scala.reflect.runtime.universe._

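/**
 * Base suite verifying that CQL `vector` columns round-trip through the DataFrame, RDD and SQL APIs.
 *
 * @tparam ScalaType     Scala element type of the vector (e.g. Int, Float)
 * @tparam DriverType    boxed [[java.lang.Number]] subtype the Java driver uses for the element
 * @tparam CaseClassType case class mapped to the test table (an `id` plus a vector column `v`)
 * @param typeName       CQL name of the vector element type (e.g. "INT", "BIGINT")
 */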
abstract class VectorTypeTest[
    ScalaType: ClassTag : TypeTag,
    DriverType <: Number : ClassTag,
    CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper : RowReaderFactory : ValidRDDType
  ](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster {
  /** Skips the given test if the cluster is not Cassandra. */
  override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)

  override lazy val conn = CassandraConnector(sparkConf)

  val VectorTable = "vectors"

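  /** Creates a table with an `id` partition key and a 3-dimensional vector column of the element type under test. */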
  def createVectorTable(session: CqlSession, table: String): Unit = {
    session.execute(
      s"""CREATE TABLE IF NOT EXISTS $ks.$table (
         |  id INT PRIMARY KEY,
         |  v VECTOR<$typeName, 3>
         |)""".stripMargin)
  }

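  /** Minimum server versions for vector support; the CQL vector type first appeared in Cassandra 5.0. */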
  def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))

  def minDSEVersion: Option[Version] = None

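  /** Builds a test vector of the element type from the given ints; implementations may add a fractional offset. */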
  def vectorFromInts(ints: Int*): Seq[ScalaType]

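  /** Wraps an id and a vector into the case class that maps to the test table. */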
  def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType

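  /** Session with the connector extensions and a `casscatalog` Cassandra catalog registered, so both SQL and DataFrame APIs can be exercised. */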
  override lazy val spark = SparkSession.builder()
    .config(sparkConf)
    .config("spark.sql.catalog.casscatalog", classOf[CassandraCatalog].getName)
    .withExtensions(new CassandraSparkExtensions)
    .getOrCreate()
    .newSession()

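  /** Creates the test keyspace once, before any test in the suite runs. */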
  override def beforeClass(): Unit = {
    conn.withSessionDo { session =>
      session.execute(
        s"""CREATE KEYSPACE IF NOT EXISTS $ks
           |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }""".stripMargin)
    }
  }

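  /** Asserts that rows with ids 1..n hold exactly the expected vectors, compared element-wise in id order. */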
  private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
    val elementClass = classTag[DriverType].runtimeClass.asInstanceOf[Class[Number]]
    val returnedVectors = for (i <- expectedVectors.indices) yield {
      rows.find(_.getInt("id") == i + 1).get.getVector("v", elementClass).iterator().toSeq
    }

    returnedVectors should contain theSameElementsInOrderAs expectedVectors
  }

| 74 | + "SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) { |
| 75 | + val table = s"${typeName.toLowerCase}_write_caseclass_to_df" |
| 76 | + conn.withSessionDo { session => |
| 77 | + createVectorTable(session, table) |
| 78 | + |
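      // Plain append: inserts rows 1 and 2.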
      spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Append)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

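      // Appending again upserts row 2 and adds row 3; row 1 is left intact.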
      spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Append)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))

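      // Overwrite with confirm.truncate truncates the table before writing rows 1 and 2.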
      spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Overwrite)
        .option("confirm.truncate", value = true)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))
    }
  }

  it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_tuple_to_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
        .toDF("id", "v")
        .write
        .cassandraFormat(table, ks)
        .mode(SaveMode.Append)
        .save()
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
    }
  }

  it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
        .saveToCassandra(ks, table)
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
    }
  }

  it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)

      spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
        .saveToCassandra(ks, table)
      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
    }
  }

  it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    import spark.implicits._
    spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
      Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
  }

  it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_tuple_from_df"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    import spark.implicits._
    spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
  }

  it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
      Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
  }

  it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
  }

  it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
    val table = s"${typeName.toLowerCase}_read_rows_from_sql"
    conn.withSessionDo { session =>
      createVectorTable(session, table)
    }
    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
      .saveToCassandra(ks, table)

    import spark.implicits._
    spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
  }

}

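/** Concrete bindings of [[VectorTypeTest]], one per supported vector element type. */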
class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
  override def vectorFromInts(ints: Int*): Seq[Int] = ints

  override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
}

case class IntVectorItem(id: Int, v: Seq[Int])

class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
  override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)

  override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
}

case class LongVectorItem(id: Int, v: Seq[Long])

class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
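  // The 0.1f offset makes the values non-integral, so genuine float handling is exercised.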
  override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)

  override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
}

case class FloatVectorItem(id: Int, v: Seq[Float])

class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
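  // The 0.1d offset makes the values non-integral, so genuine double handling is exercised.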
  override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)

  override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
}

case class DoubleVectorItem(id: Int, v: Seq[Double])