diff --git a/build.sbt b/build.sbt index 6b9c8dd5..489cccf3 100644 --- a/build.sbt +++ b/build.sbt @@ -52,9 +52,10 @@ lazy val migrator = (project in file("migrator")).enablePlugins(BuildInfoPlugin) "com.github.jnr" % "jnr-posix" % "3.1.19", // Needed by the Spark ScyllaDB connector "com.scylladb.alternator" % "emr-dynamodb-hadoop" % "5.8.0", "com.scylladb.alternator" % "load-balancing" % "1.0.0", - "io.circe" %% "circe-generic" % "0.14.7", - "io.circe" %% "circe-parser" % "0.14.7", - "io.circe" %% "circe-yaml" % "0.15.1", + "io.circe" %% "circe-generic" % "0.14.7", + "io.circe" %% "circe-parser" % "0.14.7", + "io.circe" %% "circe-yaml" % "0.15.1", + "io.circe" %% "circe-generic-extras" % "0.14.4", ), assembly / assemblyShadeRules := Seq( ShadeRule.rename("org.yaml.snakeyaml.**" -> "com.scylladb.shaded.@1").inAll diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml index 948854ec..93f51dce 100644 --- a/docker-compose-tests.yml +++ b/docker-compose-tests.yml @@ -27,8 +27,12 @@ services: - "9044:9042" expose: - 9044 - command: "--smp 1 --memory 2048M --alternator-port 8000 --alternator-write-isolation only_rmw_uses_lwt" - + command: > + --smp 1 + --memory 2048M + --alternator-port 8000 + --alternator-write-isolation only_rmw_uses_lwt + --tablets-mode-for-new-keyspaces disabled scylla: image: scylladb/scylla:latest volumes: @@ -36,7 +40,12 @@ services: ports: - "8000:8000" - "9042:9042" - command: "--smp 1 --memory 2048M --alternator-port 8000 --alternator-write-isolation only_rmw_uses_lwt" + command: > + --smp 1 + --memory 2048M + --alternator-port 8000 + --alternator-write-isolation only_rmw_uses_lwt + --tablets-mode-for-new-keyspaces disabled s3: image: localstack/localstack:s3-latest diff --git a/migrator/src/main/scala/com/scylladb/migrator/Migrator.scala b/migrator/src/main/scala/com/scylladb/migrator/Migrator.scala index 8829761f..bd1e1fa8 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/Migrator.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/Migrator.scala @@ -40,8 +40,7 @@ object Migrator { migratorConfig.getSkipTokenRangesOrEmptySet) ScyllaMigrator.migrate(migratorConfig, scyllaTarget, sourceDF) case (parquetSource: SourceSettings.Parquet, scyllaTarget: TargetSettings.Scylla) => - val sourceDF = readers.Parquet.readDataFrame(spark, parquetSource) - ScyllaMigrator.migrate(migratorConfig, scyllaTarget, sourceDF) + readers.Parquet.migrateToScylla(migratorConfig, parquetSource, scyllaTarget)(spark) case (dynamoSource: SourceSettings.DynamoDB, alternatorTarget: TargetSettings.DynamoDB) => AlternatorMigrator.migrateFromDynamoDB(dynamoSource, alternatorTarget, migratorConfig) case ( diff --git a/migrator/src/main/scala/com/scylladb/migrator/alternator/StringSetAccumulator.scala b/migrator/src/main/scala/com/scylladb/migrator/alternator/StringSetAccumulator.scala new file mode 100644 index 00000000..64858747 --- /dev/null +++ b/migrator/src/main/scala/com/scylladb/migrator/alternator/StringSetAccumulator.scala @@ -0,0 +1,38 @@ +package com.scylladb.migrator.alternator + +import org.apache.spark.util.AccumulatorV2 +import java.util.concurrent.atomic.AtomicReference + +/** + * Accumulator for tracking processed Parquet file paths during migration. + * + * This accumulator collects the set of Parquet file paths that have been processed + * as part of a migration job. It is useful for monitoring progress, avoiding duplicate + * processing, and debugging migration workflows. The accumulator is thread-safe and + * can be used in distributed Spark jobs. 
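+ *
+ * A minimal usage sketch (illustrative only; the file path below is a placeholder,
+ * while the accumulator name matches the one registered by ParquetSavepointsManager):
+ * {{{
+ *   val processedFiles = new StringSetAccumulator()
+ *   sparkContext.register(processedFiles, "processed-parquet-files") // driver side
+ *   processedFiles.add("s3a://bucket/part-00000.parquet")            // called from tasks
+ *   processedFiles.value                                             // Set("s3a://bucket/part-00000.parquet")
+ * }}}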
+ * + * @param initialValue The initial set of processed file paths (usually empty). + */ +class StringSetAccumulator(initialValue: Set[String] = Set.empty) + extends AccumulatorV2[String, Set[String]] { + + private val ref = new AtomicReference(initialValue) + + // Note: isZero may be momentarily inconsistent in concurrent scenarios, + // as it reads the current value of the set without synchronization. + // This is eventually consistent and thread-safe, but may not reflect the most recent updates. + def isZero: Boolean = ref.get.isEmpty + def copy(): StringSetAccumulator = new StringSetAccumulator(ref.get) + def reset(): Unit = ref.set(Set.empty) + def add(v: String): Unit = ref.getAndUpdate(_ + v) + + def merge(other: AccumulatorV2[String, Set[String]]): Unit = + ref.getAndUpdate(_ ++ other.value) + + def value: Set[String] = ref.get +} + +object StringSetAccumulator { + def apply(initialValue: Set[String] = Set.empty): StringSetAccumulator = + new StringSetAccumulator(initialValue) +} diff --git a/migrator/src/main/scala/com/scylladb/migrator/config/MigratorConfig.scala b/migrator/src/main/scala/com/scylladb/migrator/config/MigratorConfig.scala index 7768abd5..2aac0e6c 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/config/MigratorConfig.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/config/MigratorConfig.scala @@ -14,6 +14,7 @@ case class MigratorConfig(source: SourceSettings, savepoints: Savepoints, skipTokenRanges: Option[Set[(Token[_], Token[_])]], skipSegments: Option[Set[Int]], + skipParquetFiles: Option[Set[String]], validation: Option[Validation]) { def render: String = this.asJson.asYaml.spaces2 @@ -25,6 +26,8 @@ case class MigratorConfig(source: SourceSettings, def getSkipTokenRangesOrEmptySet: Set[(Token[_], Token[_])] = skipTokenRanges.getOrElse(Set.empty) + def getSkipParquetFilesOrEmptySet: Set[String] = skipParquetFiles.getOrElse(Set.empty) + } object MigratorConfig { implicit val tokenEncoder: Encoder[Token[_]] = Encoder.instance { diff --git a/migrator/src/main/scala/com/scylladb/migrator/config/Savepoints.scala b/migrator/src/main/scala/com/scylladb/migrator/config/Savepoints.scala index b12fac6f..5b04c0be 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/config/Savepoints.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/config/Savepoints.scala @@ -1,10 +1,12 @@ package com.scylladb.migrator.config -import io.circe.{ Decoder, Encoder } -import io.circe.generic.semiauto.{ deriveDecoder, deriveEncoder } +import io.circe.Codec +import io.circe.generic.extras.Configuration +import io.circe.generic.extras.semiauto._ + +case class Savepoints(intervalSeconds: Int, path: String, enableParquetFileTracking: Boolean = true) -case class Savepoints(intervalSeconds: Int, path: String) object Savepoints { - implicit val encoder: Encoder[Savepoints] = deriveEncoder[Savepoints] - implicit val decoder: Decoder[Savepoints] = deriveDecoder[Savepoints] + implicit val config: Configuration = Configuration.default.withDefaults + implicit val codec: Codec[Savepoints] = deriveConfiguredCodec[Savepoints] } diff --git a/migrator/src/main/scala/com/scylladb/migrator/readers/FileCompletionListener.scala b/migrator/src/main/scala/com/scylladb/migrator/readers/FileCompletionListener.scala new file mode 100644 index 00000000..66c8575e --- /dev/null +++ b/migrator/src/main/scala/com/scylladb/migrator/readers/FileCompletionListener.scala @@ -0,0 +1,93 @@ +package com.scylladb.migrator.readers + +import org.apache.log4j.LogManager +import 
org.apache.spark.scheduler.{ SparkListener, SparkListenerTaskEnd } +import org.apache.spark.Success + +import scala.collection.concurrent.TrieMap + +/** + * SparkListener that tracks partition completion and aggregates it to file-level completion. + * + * This listener monitors Spark task completion events and maintains mappings between + * partitions and files. When all partitions belonging to a file have been successfully + * completed, it marks the file as processed via the ParquetSavepointsManager. + * + * @param partitionToFiles Mapping from Spark partition ID to source file paths + * @param fileToPartitions Mapping from file path to the set of partition IDs reading from it + * @param savepointsManager Manager to notify when files are completed + */ +class FileCompletionListener( + partitionToFiles: Map[Int, Set[String]], + fileToPartitions: Map[String, Set[Int]], + savepointsManager: ParquetSavepointsManager +) extends SparkListener { + + private val log = LogManager.getLogger("com.scylladb.migrator.readers.FileCompletionListener") + + private val completedPartitions = TrieMap.empty[Int, Boolean] + + private val completedFiles = TrieMap.empty[String, Boolean] + + log.info( + s"FileCompletionListener initialized: tracking ${fileToPartitions.size} files " + + s"across ${partitionToFiles.size} partitions") + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = + if (taskEnd.reason == Success) { + val partitionId = taskEnd.taskInfo.partitionId + + partitionToFiles.get(partitionId) match { + case Some(filenames) => + if (completedPartitions.putIfAbsent(partitionId, true).isEmpty) { + filenames.foreach { filename => + log.debug(s"Partition $partitionId completed (file: $filename)") + checkFileCompletion(filename) + } + } + + case None => + log.trace(s"Task completed for untracked partition $partitionId") + } + } else { + log.debug( + s"Task for partition ${taskEnd.taskInfo.partitionId} did not complete successfully: ${taskEnd.reason}") + } + + private def checkFileCompletion(filename: String): Unit = { + if (completedFiles.contains(filename)) { + return + } + + fileToPartitions.get(filename) match { + case Some(allPartitions) => + val allComplete = allPartitions.forall(completedPartitions.contains) + + if (allComplete) { + if (completedFiles.putIfAbsent(filename, true).isEmpty) { + savepointsManager.markFileAsProcessed(filename) + + val progress = s"${completedFiles.size}/${fileToPartitions.size}" + log.info(s"File completed: $filename (progress: $progress)") + } + } else { + val completedCount = allPartitions.count(completedPartitions.contains) + log.trace(s"File $filename: $completedCount/${allPartitions.size} partitions complete") + } + + case None => + log.warn(s"File $filename not found in fileToPartitions map (this shouldn't happen)") + } + } + + def getCompletedFilesCount: Int = completedFiles.size + + def getTotalFilesCount: Int = fileToPartitions.size + + def getProgressReport: String = { + val filesCompleted = getCompletedFilesCount + val totalFiles = getTotalFilesCount + + s"Progress: $filesCompleted/$totalFiles files" + } +} diff --git a/migrator/src/main/scala/com/scylladb/migrator/readers/Parquet.scala b/migrator/src/main/scala/com/scylladb/migrator/readers/Parquet.scala index 3d473eec..4a20b9a5 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/readers/Parquet.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/readers/Parquet.scala @@ -1,14 +1,151 @@ package com.scylladb.migrator.readers -import com.scylladb.migrator.config.SourceSettings -import 
com.scylladb.migrator.scylla.SourceDataFrame +import com.scylladb.migrator.config.{ MigratorConfig, SourceSettings, TargetSettings } +import com.scylladb.migrator.scylla.{ ScyllaMigrator, ScyllaParquetMigrator, SourceDataFrame } import org.apache.log4j.LogManager -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{ AnalysisException, SparkSession } +import scala.util.Using object Parquet { val log = LogManager.getLogger("com.scylladb.migrator.readers.Parquet") - def readDataFrame(spark: SparkSession, source: SourceSettings.Parquet): SourceDataFrame = { + def migrateToScylla(config: MigratorConfig, + source: SourceSettings.Parquet, + target: TargetSettings.Scylla)(implicit spark: SparkSession): Unit = { + + val useFileTracking = config.savepoints.enableParquetFileTracking + + if (useFileTracking) { + log.info( + "Starting Parquet migration with file-level savepoint tracking") + migrateWithSavepoints(config, source, target) + } else { + log.info("Starting Parquet migration without savepoint tracking") + migrateWithoutSavepoints(config, source, target) + } + } + + /** + * Parquet migration with file-level savepoint tracking. + * + * This mode tracks completion of individual Parquet files, enabling resume capability + * if the migration is interrupted. Uses SparkListener to detect when all partitions + * of a file have been processed. + */ + private def migrateWithSavepoints( + config: MigratorConfig, + source: SourceSettings.Parquet, + target: TargetSettings.Scylla)(implicit spark: SparkSession): Unit = { + configureHadoopCredentials(spark, source) + + val allFiles = listParquetFiles(spark, source.path) + val skipFiles = config.getSkipParquetFilesOrEmptySet + val filesToProcess = allFiles.filterNot(skipFiles.contains) + + if (filesToProcess.isEmpty) { + log.info("No Parquet files to process. Migration is complete.") + return + } + + log.info(s"Processing ${filesToProcess.size} Parquet files") + + val df = if (skipFiles.isEmpty) { + spark.read.parquet(source.path) + } else { + spark.read.parquet(filesToProcess: _*) + } + + log.info("Reading partition metadata for file tracking...") + val metadata = PartitionMetadataReader.readMetadataFromDataFrame(df) + + val partitionToFiles = PartitionMetadataReader.buildPartitionToFileMap(metadata) + val fileToPartitions = PartitionMetadataReader.buildFileToPartitionsMap(metadata) + + log.info( + s"Discovered ${fileToPartitions.size} files with ${metadata.size} total partitions to process") + + Using.resource(ParquetSavepointsManager(config, spark.sparkContext)) { savepointsManager => + val listener = new FileCompletionListener( + partitionToFiles, + fileToPartitions, + savepointsManager + ) + spark.sparkContext.addSparkListener(listener) + + try { + val sourceDF = SourceDataFrame(df, None, savepointsSupported = false) + + log.info("Created DataFrame from Parquet source") + + ScyllaParquetMigrator.migrate(config, target, sourceDF, savepointsManager) + + savepointsManager.dumpMigrationState("completed") + + log.info( + s"Parquet migration completed successfully: " + + s"${listener.getCompletedFilesCount}/${listener.getTotalFilesCount} files processed") + + } finally { + spark.sparkContext.removeSparkListener(listener) + log.info(s"Final progress: ${listener.getProgressReport}") + } + } + } + + /** + * Parquet migration without savepoint tracking. + * + * This mode reads all Parquet files using Spark's native parallelism but does not + * track individual file completion. If migration is interrupted, it will restart + * from the beginning. 
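+ *
+ * This mode is selected from the savepoints section of the migration configuration.
+ * A sketch mirroring the test configurations in this change (the path is a placeholder):
+ * {{{
+ *   savepoints:
+ *     path: /app/savepoints
+ *     intervalSeconds: 300
+ *     enableParquetFileTracking: false
+ * }}}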
+ */ + private def migrateWithoutSavepoints( + config: MigratorConfig, + source: SourceSettings.Parquet, + target: TargetSettings.Scylla)(implicit spark: SparkSession): Unit = { + val sourceDF = ParquetWithoutSavepoints.readDataFrame(spark, source) + ScyllaMigrator.migrate(config, target, sourceDF) + } + + def listParquetFiles(spark: SparkSession, path: String): Seq[String] = { + log.info(s"Discovering Parquet files in $path") + + try { + val dataFrame = spark.read + .option("recursiveFileLookup", "true") + .parquet(path) + + val files = dataFrame.inputFiles.toSeq.distinct.sorted + + if (files.isEmpty) { + throw new IllegalArgumentException(s"No Parquet files found in $path") + } + + log.info(s"Found ${files.size} Parquet file(s)") + files + } catch { + case e: AnalysisException => + val message = s"Failed to list Parquet files from $path: ${e.getMessage}" + log.error(message) + throw new IllegalArgumentException(message, e) + } + } + + /** + * Configures Hadoop S3A credentials for reading from AWS S3. + * + * This method sets the necessary Hadoop configuration properties for AWS access key, secret key, + * and optionally a session token. When a session token is present, it sets the credentials provider + * to TemporaryAWSCredentialsProvider as required by Hadoop. + * + * If a region is specified in the source configuration, this method also sets the S3A endpoint region + * via the `fs.s3a.endpoint.region` property. + * + * For more details, see the official Hadoop AWS documentation: + * https://hadoop.apache.org/docs/stable/hadoop-aws/tools/hadoop-aws/index.html#Authentication + */ + private[readers] def configureHadoopCredentials(spark: SparkSession, + source: SourceSettings.Parquet): Unit = source.finalCredentials.foreach { credentials => log.info("Loaded AWS credentials from config file") source.region.foreach { region => @@ -16,7 +153,6 @@ object Parquet { } spark.sparkContext.hadoopConfiguration.set("fs.s3a.access.key", credentials.accessKey) spark.sparkContext.hadoopConfiguration.set("fs.s3a.secret.key", credentials.secretKey) - // See https://hadoop.apache.org/docs/stable/hadoop-aws/tools/hadoop-aws/index.html#Using_Session_Credentials_with_TemporaryAWSCredentialsProvider credentials.maybeSessionToken.foreach { sessionToken => spark.sparkContext.hadoopConfiguration.set( "fs.s3a.aws.credentials.provider", @@ -27,8 +163,4 @@ object Parquet { ) } } - - SourceDataFrame(spark.read.parquet(source.path), None, false) - } - } diff --git a/migrator/src/main/scala/com/scylladb/migrator/readers/ParquetSavepointsManager.scala b/migrator/src/main/scala/com/scylladb/migrator/readers/ParquetSavepointsManager.scala new file mode 100644 index 00000000..8895d203 --- /dev/null +++ b/migrator/src/main/scala/com/scylladb/migrator/readers/ParquetSavepointsManager.scala @@ -0,0 +1,36 @@ +package com.scylladb.migrator.readers + +import com.scylladb.migrator.SavepointsManager +import com.scylladb.migrator.config.MigratorConfig +import com.scylladb.migrator.alternator.StringSetAccumulator +import org.apache.spark.SparkContext + +class ParquetSavepointsManager(migratorConfig: MigratorConfig, + filesAccumulator: StringSetAccumulator) + extends SavepointsManager(migratorConfig) { + + def describeMigrationState(): String = { + val processedCount = filesAccumulator.value.size + s"Processed files: $processedCount" + } + + def updateConfigWithMigrationState(): MigratorConfig = + migratorConfig.copy(skipParquetFiles = Some(filesAccumulator.value)) + + def markFileAsProcessed(filePath: String): Unit = { + 
filesAccumulator.add(filePath) + log.debug(s"Marked file as processed: $filePath") + } +} + +object ParquetSavepointsManager { + + def apply(migratorConfig: MigratorConfig, spark: SparkContext): ParquetSavepointsManager = { + val filesAccumulator = + StringSetAccumulator(migratorConfig.skipParquetFiles.getOrElse(Set.empty)) + + spark.register(filesAccumulator, "processed-parquet-files") + + new ParquetSavepointsManager(migratorConfig, filesAccumulator) + } +} diff --git a/migrator/src/main/scala/com/scylladb/migrator/readers/ParquetWithoutSavepoints.scala b/migrator/src/main/scala/com/scylladb/migrator/readers/ParquetWithoutSavepoints.scala new file mode 100644 index 00000000..4a6e78bf --- /dev/null +++ b/migrator/src/main/scala/com/scylladb/migrator/readers/ParquetWithoutSavepoints.scala @@ -0,0 +1,27 @@ +package com.scylladb.migrator.readers + +import com.scylladb.migrator.config.SourceSettings +import com.scylladb.migrator.scylla.SourceDataFrame +import org.apache.log4j.LogManager +import org.apache.spark.sql.SparkSession + +/** + * Parquet reader implementation without savepoint tracking. + * + * This implementation provides simple Parquet file reading without file-level savepoint tracking. + * Enable via configuration: `savepoints.enableParquetFileTracking = false` + */ +object ParquetWithoutSavepoints { + val log = LogManager.getLogger("com.scylladb.migrator.readers.ParquetWithoutSavepoints") + + def readDataFrame(spark: SparkSession, source: SourceSettings.Parquet): SourceDataFrame = { + log.info(s"Reading Parquet files from ${source.path} (without savepoint tracking)") + + Parquet.configureHadoopCredentials(spark, source) + + val df = spark.read.parquet(source.path) + log.info(s"Loaded Parquet DataFrame with ${df.rdd.getNumPartitions} partitions") + + SourceDataFrame(df, None, savepointsSupported = false) + } +} diff --git a/migrator/src/main/scala/com/scylladb/migrator/readers/PartitionMetadataReader.scala b/migrator/src/main/scala/com/scylladb/migrator/readers/PartitionMetadataReader.scala new file mode 100644 index 00000000..bed63704 --- /dev/null +++ b/migrator/src/main/scala/com/scylladb/migrator/readers/PartitionMetadataReader.scala @@ -0,0 +1,77 @@ +package com.scylladb.migrator.readers + +import org.apache.log4j.LogManager +import org.apache.spark.sql.{ DataFrame, SparkSession } +import org.apache.spark.sql.execution.datasources.PartitionMetadataExtractor + +case class PartitionMetadata( + partitionId: Int, + filename: String +) + +/** + * This reader uses Spark's internal partition information to build mappings + * between partition IDs and file paths. This allows us to track when all + * partitions of a file have been processed, enabling file-level savepoints. 
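+ *
+ * A small illustration of the two mappings built by this object (file names are hypothetical):
+ * {{{
+ *   val metadata = Seq(
+ *     PartitionMetadata(0, "part-00000.parquet"),
+ *     PartitionMetadata(1, "part-00000.parquet"),
+ *     PartitionMetadata(2, "part-00001.parquet")
+ *   )
+ *   PartitionMetadataReader.buildFileToPartitionsMap(metadata)
+ *   //   Map("part-00000.parquet" -> Set(0, 1), "part-00001.parquet" -> Set(2))
+ *   PartitionMetadataReader.buildPartitionToFileMap(metadata)
+ *   //   Map(0 -> Set("part-00000.parquet"), 1 -> Set("part-00000.parquet"), 2 -> Set("part-00001.parquet"))
+ * }}}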
+ */ +object PartitionMetadataReader { + private val logger = LogManager.getLogger("com.scylladb.migrator.readers.PartitionMetadataReader") + + def readMetadata(spark: SparkSession, filePaths: Seq[String]): Seq[PartitionMetadata] = { + logger.info(s"Reading partition metadata from ${filePaths.size} file(s)") + val df = spark.read.parquet(filePaths: _*) + readMetadataFromDataFrame(df) + } + + def readMetadataFromDataFrame(df: DataFrame): Seq[PartitionMetadata] = + try { + logger.info("Extracting partition metadata from execution plan") + + val fileMap: Map[Int, Seq[String]] = PartitionMetadataExtractor.getPartitionFiles(df) + + val metadata = fileMap.flatMap { + case (partId, files) => + files.map(f => PartitionMetadata(partId, f)) + }.toSeq + + logger.info(s"Discovered ${metadata.size} partition-to-file mappings") + + if (logger.isDebugEnabled) { + val fileStats = metadata.groupBy(_.filename).view.mapValues(_.size) + logger.debug(s"Files distribution: ${fileStats.size} unique files") + fileStats.foreach { + case (file, partCount) => + logger.debug(s" File: $file -> $partCount partition(s)") + } + } + + metadata + + } catch { + case e: Exception => + logger.error(s"Failed to read partition metadata", e) + throw new RuntimeException(s"Could not read partition metadata: ${e.getMessage}", e) + } + + def buildFileToPartitionsMap(metadata: Seq[PartitionMetadata]): Map[String, Set[Int]] = { + val result = metadata + .groupBy(_.filename) + .view + .mapValues(_.map(_.partitionId).toSet) + .toMap + + logger.debug(s"Built file-to-partitions map with ${result.size} files") + result + } + + def buildPartitionToFileMap(metadata: Seq[PartitionMetadata]): Map[Int, Set[String]] = { + val result = metadata + .groupBy(_.partitionId) + .view + .mapValues(_.map(_.filename).toSet) + .toMap + + logger.debug(s"Built partition-to-file map with ${result.size} partitions") + result + } +} diff --git a/migrator/src/main/scala/com/scylladb/migrator/scylla/ScyllaMigrator.scala b/migrator/src/main/scala/com/scylladb/migrator/scylla/ScyllaMigrator.scala index 849cbb1e..53a1217e 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/scylla/ScyllaMigrator.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/scylla/ScyllaMigrator.scala @@ -3,8 +3,9 @@ package com.scylladb.migrator.scylla import com.datastax.spark.connector.rdd.partitioner.{ CassandraPartition, CqlTokenRange } import com.datastax.spark.connector.rdd.partitioner.dht.Token import com.datastax.spark.connector.writer.TokenRangeAccumulator +import com.scylladb.migrator.SavepointsManager import com.scylladb.migrator.config.{ MigratorConfig, SourceSettings, TargetSettings } -import com.scylladb.migrator.readers.TimestampColumns +import com.scylladb.migrator.readers.{ ParquetSavepointsManager, TimestampColumns } import com.scylladb.migrator.writers import org.apache.log4j.LogManager import org.apache.spark.sql.{ DataFrame, SparkSession } @@ -15,24 +16,30 @@ case class SourceDataFrame(dataFrame: DataFrame, timestampColumns: Option[TimestampColumns], savepointsSupported: Boolean) -object ScyllaMigrator { - val log = LogManager.getLogger("com.scylladb.migrator.scylla") +trait ScyllaMigratorBase { + protected val log = LogManager.getLogger("com.scylladb.migrator.scylla") - def migrate(migratorConfig: MigratorConfig, - target: TargetSettings.Scylla, - sourceDF: SourceDataFrame)(implicit spark: SparkSession): Unit = { + protected def externalSavepointsManager: Option[SavepointsManager] = None + + protected def createSavepointsManager( + migratorConfig: 
MigratorConfig, + sourceDF: SourceDataFrame + )(implicit spark: SparkSession): Option[SavepointsManager] + + protected def shouldCloseManager(manager: SavepointsManager): Boolean + + def migrate( + migratorConfig: MigratorConfig, + target: TargetSettings.Scylla, + sourceDF: SourceDataFrame + )(implicit spark: SparkSession): Unit = { log.info("Created source dataframe; resulting schema:") sourceDF.dataFrame.printSchema() - val maybeSavepointsManager = - if (!sourceDF.savepointsSupported) None - else { - val tokenRangeAccumulator = TokenRangeAccumulator.empty - spark.sparkContext.register(tokenRangeAccumulator, "Token ranges copied") - - Some(CqlSavepointsManager(migratorConfig, tokenRangeAccumulator)) - } + val maybeSavepointsManager = externalSavepointsManager.orElse( + createSavepointsManager(migratorConfig, sourceDF) + ) log.info( "We need to transfer: " + sourceDF.dataFrame.rdd.getNumPartitions + " partitions in total") @@ -71,12 +78,16 @@ object ScyllaMigrator { log.info("Starting write...") try { + val tokenRangeAccumulator = maybeSavepointsManager.flatMap { + case cqlManager: CqlSavepointsManager => Some(cqlManager.accumulator) + case _ => None + } writers.Scylla.writeDataframe( target, migratorConfig.getRenamesOrNil, sourceDF.dataFrame, sourceDF.timestampColumns, - maybeSavepointsManager.map(_.accumulator)) + tokenRangeAccumulator) } catch { case NonFatal(e) => // Catching everything on purpose to try and dump the accumulator state log.error( @@ -85,9 +96,52 @@ object ScyllaMigrator { } finally { for (savePointsManger <- maybeSavepointsManager) { savePointsManger.dumpMigrationState("final") - savePointsManger.close() + if (shouldCloseManager(savePointsManger)) { + savePointsManger.close() + } } } } +} + +object ScyllaMigrator extends ScyllaMigratorBase { + + protected override def createSavepointsManager( + migratorConfig: MigratorConfig, + sourceDF: SourceDataFrame + )(implicit spark: SparkSession): Option[SavepointsManager] = + if (!sourceDF.savepointsSupported) None + else { + val tokenRangeAccumulator = TokenRangeAccumulator.empty + spark.sparkContext.register(tokenRangeAccumulator, "Token ranges copied") + Some(CqlSavepointsManager(migratorConfig, tokenRangeAccumulator)) + } + + protected override def shouldCloseManager(manager: SavepointsManager): Boolean = true +} + +class ScyllaParquetMigrator(savepointsManager: ParquetSavepointsManager) + extends ScyllaMigratorBase { + + protected override def externalSavepointsManager: Option[SavepointsManager] = { + log.info("Using external Parquet savepoints manager") + Some(savepointsManager) + } + + protected override def createSavepointsManager( + migratorConfig: MigratorConfig, + sourceDF: SourceDataFrame + )(implicit spark: SparkSession): Option[SavepointsManager] = None + + protected override def shouldCloseManager(manager: SavepointsManager): Boolean = false +} +object ScyllaParquetMigrator { + def migrate( + migratorConfig: MigratorConfig, + target: TargetSettings.Scylla, + sourceDF: SourceDataFrame, + savepointsManager: ParquetSavepointsManager + )(implicit spark: SparkSession): Unit = + new ScyllaParquetMigrator(savepointsManager).migrate(migratorConfig, target, sourceDF) } diff --git a/migrator/src/main/scala/org/apache/spark/sql/execution/datasources/PartitionMetadataExtractor.scala b/migrator/src/main/scala/org/apache/spark/sql/execution/datasources/PartitionMetadataExtractor.scala new file mode 100644 index 00000000..2544c0b8 --- /dev/null +++ 
b/migrator/src/main/scala/org/apache/spark/sql/execution/datasources/PartitionMetadataExtractor.scala @@ -0,0 +1,53 @@ +// IMPORTANT: Must be in this package to access Spark internal API +package org.apache.spark.sql.execution.datasources + +import org.apache.log4j.LogManager +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.FileSourceScanExec + +/** + * Extracts partition-to-file mappings from the Spark execution plan. + * + * This uses Spark API to access the FileScanRDD, which already contains + * the partition-to-file mapping computed during query planning. + */ +object PartitionMetadataExtractor { + private val logger = + LogManager.getLogger("org.apache.spark.sql.execution.datasources.PartitionMetadataExtractor") + + def getPartitionFiles(df: DataFrame): Map[Int, Seq[String]] = { + logger.debug("Extracting partition-to-file mapping from execution plan") + + val plan = df.queryExecution.executedPlan + + val scanExecList = plan.collect { + case exec: FileSourceScanExec => exec + } + + val scanExec = scanExecList match { + case list if list.size == 1 => + list.head + case list if list.size > 1 => + val message = "Several FileSourceScanExec were found in plan" + logger.error(s"$message. Plan: ${plan.treeString}") + throw new IllegalArgumentException(message) + case list if list.isEmpty => + val message = "DataFrame is not based on file source (FileSourceScanExec not found in plan)" + logger.error(s"$message. Plan: ${plan.treeString}") + throw new IllegalArgumentException(message) + } + + val rdd = scanExec.inputRDD + + val partitionFiles = rdd.partitions.map { + case p: FilePartition => + val filePaths = p.files.map(_.filePath.toString).toSeq + (p.index, filePaths) + }.toMap + + logger.debug( + s"Extracted ${partitionFiles.size} partition mappings covering ${partitionFiles.values.flatten.toSet.size} unique files") + + partitionFiles + } +} diff --git a/tests/docker/.gitignore b/tests/docker/.gitignore index cd48166e..68fb0416 100644 --- a/tests/docker/.gitignore +++ b/tests/docker/.gitignore @@ -1,3 +1,2 @@ cassandra/ s3/ -spark-master/ diff --git a/tests/docker/spark-master/.gitignore b/tests/docker/spark-master/.gitignore new file mode 100644 index 00000000..76bedaea --- /dev/null +++ b/tests/docker/spark-master/.gitignore @@ -0,0 +1,5 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore + diff --git a/tests/src/test/configurations/parquet-to-scylla-legacy-comparison.yaml b/tests/src/test/configurations/parquet-to-scylla-legacy-comparison.yaml new file mode 100644 index 00000000..90e74d87 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-legacy-comparison.yaml @@ -0,0 +1,22 @@ +source: + type: parquet + path: /app/parquet/comparison-data + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: comparison + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/spark-master/comparison-savepoints + intervalSeconds: 300 + enableParquetFileTracking: false diff --git a/tests/src/test/configurations/parquet-to-scylla-legacy-singlefile.yaml b/tests/src/test/configurations/parquet-to-scylla-legacy-singlefile.yaml new file mode 100644 index 00000000..b33e4922 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-legacy-singlefile.yaml @@ -0,0 +1,22 @@ +source: + type: parquet + path: /app/parquet/legacy + +target: + type: scylla + host: scylla + port: 9042 
+ localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: singlefile + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/spark-master/legacy-savepoints + intervalSeconds: 300 + enableParquetFileTracking: false diff --git a/tests/src/test/configurations/parquet-to-scylla-legacy-test1.yaml b/tests/src/test/configurations/parquet-to-scylla-legacy-test1.yaml new file mode 100644 index 00000000..2864bd9f --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-legacy-test1.yaml @@ -0,0 +1,22 @@ +source: + type: parquet + path: /app/parquet/legacy + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: legacytest + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/spark-master/legacy-savepoints + intervalSeconds: 300 + enableParquetFileTracking: false diff --git a/tests/src/test/configurations/parquet-to-scylla-legacy-test2.yaml b/tests/src/test/configurations/parquet-to-scylla-legacy-test2.yaml new file mode 100644 index 00000000..2f248b7e --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-legacy-test2.yaml @@ -0,0 +1,22 @@ +source: + type: parquet + path: /app/parquet/legacy2 + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: legacytest2 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/spark-master/legacy-savepoints + intervalSeconds: 300 + enableParquetFileTracking: false diff --git a/tests/src/test/configurations/parquet-to-scylla-legacy-test3.yaml b/tests/src/test/configurations/parquet-to-scylla-legacy-test3.yaml new file mode 100644 index 00000000..d1e19771 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-legacy-test3.yaml @@ -0,0 +1,22 @@ +source: + type: parquet + path: /app/parquet/legacy3 + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: legacytest3 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/spark-master/legacy-noresume + intervalSeconds: 300 + enableParquetFileTracking: false diff --git a/tests/src/test/configurations/parquet-to-scylla-multipartition.yaml b/tests/src/test/configurations/parquet-to-scylla-multipartition.yaml new file mode 100644 index 00000000..4e5cdc53 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-multipartition.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: /app/parquet/multipartition + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: multipartitiontest + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints/parquet-multipartition-test + intervalSeconds: 1 diff --git a/tests/src/test/configurations/parquet-to-scylla-multipartition2.yaml b/tests/src/test/configurations/parquet-to-scylla-multipartition2.yaml new file mode 100644 index 00000000..c8510d1e --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-multipartition2.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: 
/app/parquet/multipartition2 + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: multipartitiontest2 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints/parquet-multipartition-test2 + intervalSeconds: 1 diff --git a/tests/src/test/configurations/parquet-to-scylla-newmode-comparison.yaml b/tests/src/test/configurations/parquet-to-scylla-newmode-comparison.yaml new file mode 100644 index 00000000..35fe4f14 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-newmode-comparison.yaml @@ -0,0 +1,22 @@ +source: + type: parquet + path: /app/parquet/comparison-data + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: comparison + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/spark-master/comparison-savepoints + intervalSeconds: 300 + enableParquetFileTracking: true diff --git a/tests/src/test/configurations/parquet-to-scylla-parallel-savepoints.yaml b/tests/src/test/configurations/parquet-to-scylla-parallel-savepoints.yaml new file mode 100644 index 00000000..091c9ae8 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-parallel-savepoints.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: /app/parquet/parallel-savepoints + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: paralleltest2 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints/parallel-savepoints-test + intervalSeconds: 1 diff --git a/tests/src/test/configurations/parquet-to-scylla-parallel.yaml b/tests/src/test/configurations/parquet-to-scylla-parallel.yaml new file mode 100644 index 00000000..f2c6d89d --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-parallel.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: /app/parquet/parallel + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: paralleltest + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints + intervalSeconds: 300 diff --git a/tests/src/test/configurations/parquet-to-scylla-resume.yaml b/tests/src/test/configurations/parquet-to-scylla-resume.yaml new file mode 100644 index 00000000..1c53e120 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-resume.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: /app/parquet/resume + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: resumetest + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints/parquet-resume-test + intervalSeconds: 1 diff --git a/tests/src/test/configurations/parquet-to-scylla-resume2.yaml b/tests/src/test/configurations/parquet-to-scylla-resume2.yaml new file mode 100644 index 00000000..64a5baf1 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-resume2.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: /app/parquet/resume2 + +target: + type: scylla + host: 
scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: resumetest2 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints/parquet-resume-test2 + intervalSeconds: 1 diff --git a/tests/src/test/configurations/parquet-to-scylla-savepoints.yaml b/tests/src/test/configurations/parquet-to-scylla-savepoints.yaml new file mode 100644 index 00000000..2512c512 --- /dev/null +++ b/tests/src/test/configurations/parquet-to-scylla-savepoints.yaml @@ -0,0 +1,21 @@ +source: + type: parquet + path: /app/parquet/savepoints + +target: + type: scylla + host: scylla + port: 9042 + localDC: datacenter1 + credentials: + username: dummy + password: dummy + keyspace: test + table: savepointstest + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints/parquet-savepoints-test + intervalSeconds: 1 diff --git a/tests/src/test/scala/com/scylladb/migrator/alternator/StringSetAccumulatorTest.scala b/tests/src/test/scala/com/scylladb/migrator/alternator/StringSetAccumulatorTest.scala new file mode 100644 index 00000000..6b99e0cc --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/alternator/StringSetAccumulatorTest.scala @@ -0,0 +1,31 @@ +package com.scylladb.migrator.alternator + +class StringSetAccumulatorTest extends munit.FunSuite { + test("StringSetAccumulator basic functionality") { + val accumulator = StringSetAccumulator() + + assertEquals(accumulator.isZero, true) + assertEquals(accumulator.value, Set.empty[String]) + + accumulator.add("file1.parquet") + accumulator.add("file2.parquet") + + assertEquals(accumulator.value, Set("file1.parquet", "file2.parquet")) + assertEquals(accumulator.isZero, false) + + val copy = accumulator.copy() + assertEquals(copy.value, accumulator.value) + + accumulator.reset() + assertEquals(accumulator.isZero, true) + assertEquals(accumulator.value, Set.empty[String]) + } + + test("StringSetAccumulator merge functionality") { + val accumulator1 = StringSetAccumulator(Set("file1.parquet")) + val accumulator2 = StringSetAccumulator(Set("file2.parquet", "file3.parquet")) + + accumulator1.merge(accumulator2) + assertEquals(accumulator1.value, Set("file1.parquet", "file2.parquet", "file3.parquet")) + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/config/ParquetConfigSerializationTest.scala b/tests/src/test/scala/com/scylladb/migrator/config/ParquetConfigSerializationTest.scala new file mode 100644 index 00000000..cb20e38d --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/config/ParquetConfigSerializationTest.scala @@ -0,0 +1,143 @@ +package com.scylladb.migrator.config + +import java.nio.file.{Files, Paths} +import java.nio.charset.StandardCharsets + +class ParquetConfigSerializationTest extends munit.FunSuite { + + test("skipParquetFiles serialization to YAML") { + val config = MigratorConfig( + source = SourceSettings.Parquet( + path = "s3a://test-bucket/data/", + credentials = None, + region = None, + endpoint = None + ), + target = TargetSettings.Scylla( + host = "scylla-server", + port = 9042, + localDC = Some("datacenter1"), + credentials = None, + sslOptions = None, + keyspace = "test_keyspace", + table = "test_table", + connections = Some(16), + stripTrailingZerosForDecimals = false, + writeTTLInS = None, + writeWritetimestampInuS = None, + consistencyLevel = "LOCAL_QUORUM" + ), + renames = None, + savepoints = Savepoints(300, 
"/app/savepoints"), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = Some(Set( + "s3a://test-bucket/data/part-00001.parquet", + "s3a://test-bucket/data/part-00002.parquet", + "s3a://test-bucket/data/part-00003.parquet" + )), + validation = None + ) + + val yaml = config.render + + assert(yaml.contains("skipParquetFiles")) + assert(yaml.contains("part-00001.parquet")) + assert(yaml.contains("part-00002.parquet")) + assert(yaml.contains("part-00003.parquet")) + } + + test("skipParquetFiles deserialization from YAML") { + val yamlContent = """ +source: + type: parquet + path: "s3a://test-bucket/data/" + +target: + type: scylla + host: scylla-server + port: 9042 + keyspace: test_keyspace + table: test_table + localDC: datacenter1 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints + intervalSeconds: 300 + +skipParquetFiles: + - "s3a://test-bucket/data/part-00001.parquet" + - "s3a://test-bucket/data/part-00002.parquet" + +skipTokenRanges: null +skipSegments: null +renames: null +validation: null +""" + + val tempFile = Files.createTempFile("test-config", ".yaml") + try { + Files.write(tempFile, yamlContent.getBytes(StandardCharsets.UTF_8)) + + val config = MigratorConfig.loadFrom(tempFile.toString) + + assertEquals(config.skipParquetFiles.isDefined, true) + assertEquals(config.skipParquetFiles.get.size, 2) + assert(config.skipParquetFiles.get.contains("s3a://test-bucket/data/part-00001.parquet")) + assert(config.skipParquetFiles.get.contains("s3a://test-bucket/data/part-00002.parquet")) + + assertEquals(config.getSkipParquetFilesOrEmptySet.size, 2) + + } finally { + Files.delete(tempFile) + } + } + + test("skipParquetFiles round-trip serialization") { + val originalFiles = Set( + "s3a://bucket/file1.parquet", + "file:///local/file2.parquet", + "hdfs://namenode/file3.parquet" + ) + + val config1 = MigratorConfig( + source = SourceSettings.Parquet("s3a://bucket/", None, None, None), + target = TargetSettings.Scylla( + host = "localhost", + port = 9042, + localDC = Some("dc1"), + credentials = None, + sslOptions = None, + keyspace = "ks", + table = "tbl", + connections = Some(8), + stripTrailingZerosForDecimals = false, + writeTTLInS = None, + writeWritetimestampInuS = None, + consistencyLevel = "LOCAL_QUORUM" + ), + renames = None, + savepoints = Savepoints(300, "/tmp"), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = Some(originalFiles), + validation = None + ) + + val yaml = config1.render + + val tempFile = Files.createTempFile("roundtrip-config", ".yaml") + try { + Files.write(tempFile, yaml.getBytes(StandardCharsets.UTF_8)) + val config2 = MigratorConfig.loadFrom(tempFile.toString) + + assertEquals(config2.skipParquetFiles.get, originalFiles) + } finally { + Files.delete(tempFile) + } + } + +} diff --git a/tests/src/test/scala/com/scylladb/migrator/readers/FileCompletionListenerTest.scala b/tests/src/test/scala/com/scylladb/migrator/readers/FileCompletionListenerTest.scala new file mode 100644 index 00000000..56434b42 --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/readers/FileCompletionListenerTest.scala @@ -0,0 +1,374 @@ +package com.scylladb.migrator.readers + +import com.scylladb.migrator.config.{MigratorConfig, SourceSettings} +import org.apache.spark.scheduler.{SparkListenerTaskEnd, TaskInfo} +import org.apache.spark.{Success, TaskEndReason} +import org.apache.spark.sql.SparkSession +import java.nio.file.Files + +class FileCompletionListenerTest extends 
munit.FunSuite { + + implicit val spark: SparkSession = SparkSession + .builder() + .appName("FileCompletionListenerTest") + .master("local[*]") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + + override def afterAll(): Unit = { + spark.stop() + super.afterAll() + } + + def createMockTaskEnd(partitionId: Int, stageId: Int = 0, success: Boolean = true): SparkListenerTaskEnd = { + val taskInfo = new TaskInfo( + taskId = partitionId.toLong, + index = partitionId, + attemptNumber = 0, + partitionId = partitionId, + launchTime = System.currentTimeMillis(), + executorId = "executor-1", + host = "localhost", + taskLocality = org.apache.spark.scheduler.TaskLocality.PROCESS_LOCAL, + speculative = false + ) + + val reason: TaskEndReason = if (success) Success else org.apache.spark.TaskKilled("test") + + new SparkListenerTaskEnd( + stageId = stageId, + stageAttemptId = 0, + taskType = "ResultTask", + reason = reason, + taskInfo = taskInfo, + taskExecutorMetrics = null, + taskMetrics = null + ) + } + + test("FileCompletionListener tracks single-partition files") { + val tempDir = Files.createTempDirectory("savepoints-listener-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + // Setup: 3 files, each with 1 partition + val partitionToFile = Map( + 0 -> Set("file1.parquet"), + 1 -> Set("file2.parquet"), + 2 -> Set("file3.parquet") + ) + + val fileToPartitions = Map( + "file1.parquet" -> Set(0), + "file2.parquet" -> Set(1), + "file3.parquet" -> Set(2) + ) + + val listener = new FileCompletionListener( + partitionToFile, + fileToPartitions, + manager + ) + + assertEquals(listener.getCompletedFilesCount, 0) + assertEquals(listener.getTotalFilesCount, 3) + + listener.onTaskEnd(createMockTaskEnd(0)) + + assertEquals(listener.getCompletedFilesCount, 1) + + listener.onTaskEnd(createMockTaskEnd(1)) + + assertEquals(listener.getCompletedFilesCount, 2) + + listener.onTaskEnd(createMockTaskEnd(2)) + + assertEquals(listener.getCompletedFilesCount, 3) + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("FileCompletionListener tracks multi-partition files") { + val tempDir = Files.createTempDirectory("savepoints-listener-multipart-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + val partitionToFile = Map( + 0 -> Set("file1.parquet"), + 1 -> Set("file1.parquet"), + 2 -> Set("file1.parquet"), + 3 -> Set("file2.parquet"), + 4 -> Set("file2.parquet") + ) + + val fileToPartitions = Map( + "file1.parquet" -> Set(0, 1, 2), + "file2.parquet" -> Set(3, 4) + ) + + val listener = new FileCompletionListener( + partitionToFile, + fileToPartitions, + manager + ) + + assertEquals(listener.getCompletedFilesCount, 0) + + listener.onTaskEnd(createMockTaskEnd(0)) + 
assertEquals(listener.getCompletedFilesCount, 0, "File1 should not be complete yet") + + listener.onTaskEnd(createMockTaskEnd(1)) + assertEquals(listener.getCompletedFilesCount, 0, "File1 should still not be complete") + + listener.onTaskEnd(createMockTaskEnd(2)) + assertEquals(listener.getCompletedFilesCount, 1, "File1 should be complete now") + + listener.onTaskEnd(createMockTaskEnd(3)) + assertEquals(listener.getCompletedFilesCount, 1, "File2 should not be complete yet") + + listener.onTaskEnd(createMockTaskEnd(4)) + assertEquals(listener.getCompletedFilesCount, 2, "Both files should be complete") + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("FileCompletionListener handles failed tasks correctly") { + val tempDir = Files.createTempDirectory("savepoints-listener-failure-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + val partitionToFile = Map( + 0 -> Set("file1.parquet"), + 1 -> Set("file2.parquet") + ) + + val fileToPartitions = Map( + "file1.parquet" -> Set(0), + "file2.parquet" -> Set(1) + ) + + val listener = new FileCompletionListener( + partitionToFile, + fileToPartitions, + manager + ) + + listener.onTaskEnd(createMockTaskEnd(0, success = false)) + + assertEquals(listener.getCompletedFilesCount, 0) + + listener.onTaskEnd(createMockTaskEnd(0, success = true)) + + assertEquals(listener.getCompletedFilesCount, 1) + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("FileCompletionListener is idempotent for duplicate task completions") { + val tempDir = Files.createTempDirectory("savepoints-listener-idempotent-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + val partitionToFile = Map(0 -> Set("file1.parquet")) + val fileToPartitions = Map("file1.parquet" -> Set(0)) + + val listener = new FileCompletionListener( + partitionToFile, + fileToPartitions, + manager + ) + + listener.onTaskEnd(createMockTaskEnd(0)) + listener.onTaskEnd(createMockTaskEnd(0)) + listener.onTaskEnd(createMockTaskEnd(0)) + + assertEquals(listener.getCompletedFilesCount, 1) + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("FileCompletionListener provides accurate progress report") { + val tempDir = Files.createTempDirectory("savepoints-listener-report-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + 
+ val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + val partitionToFile = Map( + 0 -> Set("file1.parquet"), + 1 -> Set("file1.parquet"), + 2 -> Set("file2.parquet") + ) + + val fileToPartitions = Map( + "file1.parquet" -> Set(0, 1), + "file2.parquet" -> Set(2) + ) + + val listener = new FileCompletionListener( + partitionToFile, + fileToPartitions, + manager + ) + + val initialReport = listener.getProgressReport + assert(initialReport.contains("0/2 files")) + + listener.onTaskEnd(createMockTaskEnd(0)) + val midReport = listener.getProgressReport + assert(midReport.contains("0/2 files"), "File not complete until all partitions done") + + listener.onTaskEnd(createMockTaskEnd(1)) + val file1CompleteReport = listener.getProgressReport + assert(file1CompleteReport.contains("1/2 files")) + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("FileCompletionListener handles multiple files per partition") { + val tempDir = Files.createTempDirectory("savepoints-listener-multi-file-partition-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + val partitionToFile = Map(0 -> Set("file1.parquet", "file2.parquet")) + val fileToPartitions = Map( + "file1.parquet" -> Set(0), + "file2.parquet" -> Set(0) + ) + + val listener = new FileCompletionListener( + partitionToFile, + fileToPartitions, + manager + ) + + listener.onTaskEnd(createMockTaskEnd(0)) + + assertEquals(listener.getCompletedFilesCount, 2) + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/readers/ParquetModeSelectionTest.scala b/tests/src/test/scala/com/scylladb/migrator/readers/ParquetModeSelectionTest.scala new file mode 100644 index 00000000..3f018f4f --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/readers/ParquetModeSelectionTest.scala @@ -0,0 +1,165 @@ +package com.scylladb.migrator.readers + +import com.scylladb.migrator.config.{MigratorConfig, Savepoints, SourceSettings, TargetSettings} +import java.nio.file.{Files, Paths} +import java.nio.charset.StandardCharsets + +class ParquetModeSelectionTest extends munit.FunSuite { + + test("savepoints.enableParquetFileTracking defaults to true when not specified") { + val yamlContent = """ +source: + type: parquet + path: "s3a://test-bucket/data/" + +target: + type: scylla + host: scylla-server + port: 9042 + keyspace: test_keyspace + table: test_table + localDC: datacenter1 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints + intervalSeconds: 300 +""" + + val tempFile = Files.createTempFile("test-config-default", ".yaml") + try { + Files.write(tempFile, yamlContent.getBytes(StandardCharsets.UTF_8)) + + val config = MigratorConfig.loadFrom(tempFile.toString) + + assertEquals(config.savepoints.enableParquetFileTracking, true, + "enableParquetFileTracking should default to true") + + } finally { + Files.delete(tempFile) + } + } + + 
test("savepoints.enableParquetFileTracking can be explicitly set to true") { + val yamlContent = """ +source: + type: parquet + path: "s3a://test-bucket/data/" + +target: + type: scylla + host: scylla-server + port: 9042 + keyspace: test_keyspace + table: test_table + localDC: datacenter1 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints + intervalSeconds: 300 + enableParquetFileTracking: true +""" + + val tempFile = Files.createTempFile("test-config-new-mode", ".yaml") + try { + Files.write(tempFile, yamlContent.getBytes(StandardCharsets.UTF_8)) + + val config = MigratorConfig.loadFrom(tempFile.toString) + + assertEquals(config.savepoints.enableParquetFileTracking, true, + "enableParquetFileTracking should be true when explicitly set") + + } finally { + Files.delete(tempFile) + } + } + + test("savepoints.enableParquetFileTracking can be set to false for legacy mode") { + val yamlContent = """ +source: + type: parquet + path: "s3a://test-bucket/data/" + +target: + type: scylla + host: scylla-server + port: 9042 + keyspace: test_keyspace + table: test_table + localDC: datacenter1 + consistencyLevel: LOCAL_QUORUM + connections: 16 + stripTrailingZerosForDecimals: false + +savepoints: + path: /app/savepoints + intervalSeconds: 300 + enableParquetFileTracking: false +""" + + val tempFile = Files.createTempFile("test-config-legacy-mode", ".yaml") + try { + Files.write(tempFile, yamlContent.getBytes(StandardCharsets.UTF_8)) + + val config = MigratorConfig.loadFrom(tempFile.toString) + + assertEquals(config.savepoints.enableParquetFileTracking, false, + "enableParquetFileTracking should be false when explicitly set") + + } finally { + Files.delete(tempFile) + } + } + + test("savepoints configuration round-trip with enableParquetFileTracking") { + val savepoints1 = Savepoints( + intervalSeconds = 60, + path = "/tmp/savepoints", + enableParquetFileTracking = false + ) + + val config1 = MigratorConfig( + source = SourceSettings.Parquet("s3a://bucket/", None, None, None), + target = TargetSettings.Scylla( + host = "localhost", + port = 9042, + localDC = Some("dc1"), + credentials = None, + sslOptions = None, + keyspace = "ks", + table = "tbl", + connections = Some(8), + stripTrailingZerosForDecimals = false, + writeTTLInS = None, + writeWritetimestampInuS = None, + consistencyLevel = "LOCAL_QUORUM" + ), + renames = None, + savepoints = savepoints1, + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val yaml = config1.render + + val tempFile = Files.createTempFile("roundtrip-savepoints", ".yaml") + try { + Files.write(tempFile, yaml.getBytes(StandardCharsets.UTF_8)) + val config2 = MigratorConfig.loadFrom(tempFile.toString) + + assertEquals(config2.savepoints.intervalSeconds, savepoints1.intervalSeconds) + assertEquals(config2.savepoints.path, savepoints1.path) + assertEquals(config2.savepoints.enableParquetFileTracking, savepoints1.enableParquetFileTracking, + "enableParquetFileTracking should survive round-trip serialization") + } finally { + Files.delete(tempFile) + } + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/readers/PartitionMetadataReaderTest.scala b/tests/src/test/scala/com/scylladb/migrator/readers/PartitionMetadataReaderTest.scala new file mode 100644 index 00000000..bf18d61a --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/readers/PartitionMetadataReaderTest.scala @@ -0,0 +1,182 @@ +package com.scylladb.migrator.readers + 
+import org.apache.spark.sql.SparkSession +import java.nio.file.Files + +class PartitionMetadataReaderTest extends munit.FunSuite { + + implicit val spark: SparkSession = SparkSession + .builder() + .appName("PartitionMetadataReaderTest") + .master("local[*]") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + + override def afterAll(): Unit = { + spark.stop() + super.afterAll() + } + + test("readMetadata returns partition-to-file mappings") { + val tempDir = Files.createTempDirectory("partition-metadata-test") + + try { + import spark.implicits._ + + // Create multiple parquet files + val testData1 = (1 to 10).map(i => (i, s"data$i")).toDF("id", "name") + val testData2 = (11 to 20).map(i => (i, s"data$i")).toDF("id", "name") + + val file1Path = tempDir.resolve("file1.parquet") + val file2Path = tempDir.resolve("file2.parquet") + + testData1.write.parquet(file1Path.toString) + testData2.write.parquet(file2Path.toString) + + val files = Parquet.listParquetFiles(spark, tempDir.toString) + val metadata = PartitionMetadataReader.readMetadata(spark, files) + + assert(metadata.nonEmpty, "Metadata should not be empty") + + metadata.foreach { pm => + assert(pm.partitionId >= 0, s"Partition ID should be non-negative: ${pm.partitionId}") + assert(pm.filename.nonEmpty, s"Filename should not be empty") + } + + val uniqueFiles = metadata.map(_.filename).toSet + assert(uniqueFiles.size >= 2, s"Should have at least 2 unique files, got ${uniqueFiles.size}") + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("buildFileToPartitionsMap creates correct mapping") { + val metadata = Seq( + PartitionMetadata(0, "file1.parquet"), + PartitionMetadata(1, "file1.parquet"), + PartitionMetadata(2, "file2.parquet"), + PartitionMetadata(3, "file3.parquet") + ) + + val fileToPartitions = PartitionMetadataReader.buildFileToPartitionsMap(metadata) + + assertEquals(fileToPartitions.size, 3) + assertEquals(fileToPartitions("file1.parquet"), Set(0, 1)) + assertEquals(fileToPartitions("file2.parquet"), Set(2)) + assertEquals(fileToPartitions("file3.parquet"), Set(3)) + } + + test("buildPartitionToFileMap creates correct mapping") { + val metadata = Seq( + PartitionMetadata(0, "file1.parquet"), + PartitionMetadata(1, "file1.parquet"), + PartitionMetadata(2, "file2.parquet") + ) + + val partitionToFile = PartitionMetadataReader.buildPartitionToFileMap(metadata) + + assertEquals(partitionToFile.size, 3) + assertEquals(partitionToFile(0), Set("file1.parquet")) + assertEquals(partitionToFile(1), Set("file1.parquet")) + assertEquals(partitionToFile(2), Set("file2.parquet")) + } + + test("buildPartitionToFileMap groups multiple files per partition") { + val metadata = Seq( + PartitionMetadata(0, "file1.parquet"), + PartitionMetadata(0, "file2.parquet"), + PartitionMetadata(1, "file3.parquet") + ) + + val partitionToFile = PartitionMetadataReader.buildPartitionToFileMap(metadata) + + assertEquals(partitionToFile(0), Set("file1.parquet", "file2.parquet")) + assertEquals(partitionToFile(1), Set("file3.parquet")) + } + + test("file filtering logic works correctly") { + val allFiles = Seq("file1.parquet", "file2.parquet", "file3.parquet") + val processedFiles = Set("file1.parquet", "file3.parquet") + + val filesToProcess = allFiles.filterNot(processedFiles.contains) + + assertEquals(filesToProcess.size, 1) + assertEquals(filesToProcess.head, "file2.parquet") + } + + test("file filtering returns all files when skipFiles is empty") { + val allFiles = 
Seq("file1.parquet", "file2.parquet") + val skipFiles = Set.empty[String] + + val filesToProcess = allFiles.filterNot(skipFiles.contains) + + assertEquals(filesToProcess.size, allFiles.size) + assertEquals(filesToProcess, allFiles) + } + + test("reading metadata with file filtering") { + val tempDir = Files.createTempDirectory("partition-metadata-filtering-test") + + try { + import spark.implicits._ + + val testData1 = (1 to 5).map(i => (i, s"data$i")).toDF("id", "name") + val testData2 = (6 to 10).map(i => (i, s"data$i")).toDF("id", "name") + + val file1Path = tempDir.resolve("file1.parquet") + val file2Path = tempDir.resolve("file2.parquet") + + testData1.write.parquet(file1Path.toString) + testData2.write.parquet(file2Path.toString) + + val allFiles = Parquet.listParquetFiles(spark, tempDir.toString) + val allMetadata = PartitionMetadataReader.readMetadata(spark, allFiles) + + val fileToSkip = allFiles.head + val filesToProcess = allFiles.filterNot(_ == fileToSkip) + + val filteredMetadata = PartitionMetadataReader.readMetadata(spark, filesToProcess) + + assert(filteredMetadata.size < allMetadata.size) + + filteredMetadata.foreach { pm => + assert(pm.filename != fileToSkip, s"Skipped file ${fileToSkip} should not appear in filtered metadata") + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("readMetadata handles single partition per file") { + val tempDir = Files.createTempDirectory("single-partition-test") + + try { + import spark.implicits._ + + val testData = Seq((1, "data")).toDF("id", "name") + val filePath = tempDir.resolve("small.parquet") + + testData.write.parquet(filePath.toString) + + val files = Parquet.listParquetFiles(spark, tempDir.toString) + val metadata = PartitionMetadataReader.readMetadata(spark, files) + + assert(metadata.nonEmpty, "Metadata should contain at least one entry") + + val fileToPartitions = PartitionMetadataReader.buildFileToPartitionsMap(metadata) + + assertEquals(fileToPartitions.size, 1) + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetLegacyModeTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetLegacyModeTest.scala new file mode 100644 index 00000000..6d66f95a --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetLegacyModeTest.scala @@ -0,0 +1,257 @@ +package com.scylladb.migrator.scylla + +import com.datastax.oss.driver.api.querybuilder.QueryBuilder +import com.scylladb.migrator.SparkUtils.successfullyPerformMigration +import com.scylladb.migrator.config.MigratorConfig +import java.nio.file.Files +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ +import scala.util.chaining._ + +class ParquetLegacyModeTest extends ParquetMigratorSuite { + + override val munitTimeout: FiniteDuration = 2.minutes + + FunFixture + .map2(withTable("legacytest"), withParquetDir("legacy")) + .test("Legacy mode (enableParquetFileTracking=false) migrates data correctly") { + case (tableName, parquetRoot) => + val parquetDir = parquetRoot.resolve("legacy") + Files.createDirectories(parquetDir) + + // Create multiple parquet files to ensure parallel processing works + val parquetFiles = List( + parquetDir.resolve("file-1.parquet") -> List( + TestRecord("1", "alpha", 10), + TestRecord("2", "beta", 20) + ), + parquetDir.resolve("file-2.parquet") -> List( + TestRecord("3", "gamma", 30), + 
TestRecord("4", "delta", 40) + ), + parquetDir.resolve("file-3.parquet") -> List( + TestRecord("5", "epsilon", 50) + ) + ) + + parquetFiles.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // Run migration with legacy mode + successfullyPerformMigration("parquet-to-scylla-legacy-test1.yaml") + + // Verify all data was migrated correctly + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val expectedRows = parquetFiles.flatMap(_._2).map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, expectedRows.size, "All rows should be migrated in legacy mode") + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, expectedRows(id), s"Row $id should match expected data") + } + } + } + + withTableAndSavepoints("legacytest2", "legacy2", "legacy-savepoints").test( + "Legacy mode does NOT create file-level savepoints" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("legacy2") + Files.createDirectories(parquetDir) + + val parquetFiles = List( + parquetDir.resolve("data-1.parquet") -> List( + TestRecord("1", "foo", 100) + ), + parquetDir.resolve("data-2.parquet") -> List( + TestRecord("2", "bar", 200) + ) + ) + + parquetFiles.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // Run migration with legacy mode + successfullyPerformMigration("parquet-to-scylla-legacy-test2.yaml") + + // Verify data migrated + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val rowCount = targetScylla().execute(selectAllStatement).all().size() + assertEquals(rowCount, 2, "Should have migrated 2 rows") + + // Check savepoint file if it exists + val maybeSavepointFile = findLatestSavepoint(savepointsDir) + + maybeSavepointFile match { + case Some(savepointFile) => + // If savepoint exists, it should NOT contain skipParquetFiles + val savepointConfig = MigratorConfig.loadFrom(savepointFile.toString) + assertEquals( + savepointConfig.skipParquetFiles, + None, + "Legacy mode should NOT track skipParquetFiles in savepoints" + ) + case None => + // It's also acceptable if no savepoint is created at all + // This is the expected behavior when savepointsSupported = false + () + } + } + + withTableAndSavepoints("legacytest3", "legacy3", "legacy-noresume").test( + "Legacy mode does NOT support resume - data is re-processed on restart" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("legacy3") + Files.createDirectories(parquetDir) + + writeParquetTestFile( + parquetDir.resolve("data.parquet"), + List(TestRecord("unique-id", "test-data", 999)) + ) + + // First run + successfullyPerformMigration("parquet-to-scylla-legacy-test3.yaml") + + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val rowCountAfterFirstRun = targetScylla().execute(selectAllStatement).all().size() + assertEquals(rowCountAfterFirstRun, 1, "Should have 1 row after first run") + + // Second run - in legacy mode, this should re-process the same file + // Note: This will either duplicate data OR fail with unique constraint + // depending on table schema. Since our test schema has 'id' as primary key, + // the data should be idempotent (same id overwrites). 
+ successfullyPerformMigration("parquet-to-scylla-legacy-test3.yaml") + + val rowCountAfterSecondRun = targetScylla().execute(selectAllStatement).all().size() + + // In legacy mode without resume: + // - Data is re-processed + // - But since we have primary key, it just overwrites + // - So count stays the same, but this proves no "skip" happened + assertEquals( + rowCountAfterSecondRun, + 1, + "Row count should still be 1 (overwritten, not skipped)" + ) + + // The key difference from new mode: no file was marked as "processed" + // In new mode, the second run would log "No Parquet files to process" + // In legacy mode, it re-reads and re-processes everything + } + + withTableAndSavepoints("comparison", "comparison-data", "comparison-savepoints").test( + "Legacy and new modes produce identical results on same dataset" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("comparison-data") + Files.createDirectories(parquetDir) + + // Create a representative dataset + val testData = List( + parquetDir.resolve("batch-a.parquet") -> List( + TestRecord("id-1", "data-a", 111), + TestRecord("id-2", "data-b", 222) + ), + parquetDir.resolve("batch-b.parquet") -> List( + TestRecord("id-3", "data-c", 333), + TestRecord("id-4", "data-d", 444) + ), + parquetDir.resolve("batch-c.parquet") -> List( + TestRecord("id-5", "data-e", 555) + ) + ) + + testData.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // Run with legacy mode first + successfullyPerformMigration("parquet-to-scylla-legacy-comparison.yaml") + + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + // Capture results from legacy mode + val legacyResults = targetScylla().execute(selectAllStatement).all().asScala + .map { row => + TestRecord( + row.getString("id"), + row.getString("foo"), + row.getInt("bar") + ) + } + .toSet + + assertEquals(legacyResults.size, 5, "Legacy mode should migrate all 5 rows") + + // Clear the table for second run + val truncateStatement = QueryBuilder.truncate(keyspace, tableName).build() + targetScylla().execute(truncateStatement) + + // Run with new mode + successfullyPerformMigration("parquet-to-scylla-newmode-comparison.yaml") + + // Capture results from new mode + val newModeResults = targetScylla().execute(selectAllStatement).all().asScala + .map { row => + TestRecord( + row.getString("id"), + row.getString("foo"), + row.getInt("bar") + ) + } + .toSet + + + assertEquals( + newModeResults, + legacyResults, + "New mode and legacy mode should produce identical data" + ) + } + + FunFixture + .map2(withTable("singlefile"), withParquetDir("legacy")) + .test("Legacy mode migrates single file correctly") { + case (tableName, parquetRoot) => + val parquetDir = parquetRoot.resolve("legacy") + Files.createDirectories(parquetDir) + + // Single file scenario + writeParquetTestFile( + parquetDir.resolve("single.parquet"), + List(TestRecord("only-one", "single-file", 42)) + ) + + successfullyPerformMigration("parquet-to-scylla-legacy-singlefile.yaml") + + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val rows = targetScylla().execute(selectAllStatement).all().asScala + assertEquals(rows.size, 1, "Should migrate the single row") + assertEquals(rows.head.getString("id"), "only-one") + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetMigratorSuite.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetMigratorSuite.scala new file 
mode 100644 index 00000000..3206cbec --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetMigratorSuite.scala @@ -0,0 +1,130 @@ +package com.scylladb.migrator.scylla + +import com.datastax.oss.driver.api.querybuilder.SchemaBuilder +import com.github.mjakubowski84.parquet4s.ParquetWriter +import com.scylladb.migrator.CassandraUtils.dropAndRecreateTable +import org.apache.parquet.hadoop.ParquetFileWriter + +import java.nio.file.{Files, Path, Paths} +import scala.jdk.CollectionConverters._ +import scala.util.Using + +abstract class ParquetMigratorSuite extends MigratorSuite(sourcePort = 0) { + + // parquet4s automatically derives the Parquet schema from this case class definition. + // This definition must remain consistent with the table schema created by dropAndRecreateTable, + // which has columns: id TEXT, foo TEXT, bar INT. + case class TestRecord(id: String, foo: String, bar: Int) + + protected val parquetHostRoot: Path = Paths.get("docker/parquet") + protected val savepointsHostRoot: Path = Paths.get("docker/spark-master") + + override def munitFixtures: Seq[Fixture[_]] = Seq(targetScylla) + + // Override withTable to only create table in target (no source Cassandra for Parquet tests) + override def withTable(name: String, renames: Map[String, String] = Map.empty): FunFixture[String] = + FunFixture( + setup = { _ => + try { + // Only create table in target ScyllaDB, not in source + dropAndRecreateTable( + targetScylla(), + keyspace, + name, + columnName = originalName => renames.getOrElse(originalName, originalName)) + } catch { + case any: Throwable => + fail(s"Something did not work as expected", any) + } + name + }, + teardown = { _ => + val dropTableQuery = SchemaBuilder.dropTable(keyspace, name).build() + targetScylla().execute(dropTableQuery) + () + } + ) + + def withParquetDir(parquetDir: String): FunFixture[Path] = + FunFixture( + setup = { _ => + ensureEmptyDirectory(parquetHostRoot) + parquetHostRoot + }, + teardown = { directory => + ensureEmptyDirectory(directory) + () + } + ) + + def withSavepointsDir(savepointsDir: String): FunFixture[Path] = + FunFixture( + setup = { _ => + val directory = savepointsHostRoot.resolve(savepointsDir) + ensureEmptyDirectory(directory) + directory + }, + teardown = { directory => + ensureEmptyDirectory(directory) + () + } + ) + + def withParquetAndSavepoints(parquetDir: String, savepointsDir: String): FunFixture[(Path, Path)] = + FunFixture.map2(withParquetDir(parquetDir), withSavepointsDir(savepointsDir)) + + def withTableAndSavepoints(tableName: String, parquetDir: String, savepointsDir: String): FunFixture[(String, (Path, Path))] = + FunFixture.map2(withTable(tableName), withParquetAndSavepoints(parquetDir, savepointsDir)) + + def writeParquetTestFile(path: Path, data: List[TestRecord]): Unit = { + ParquetWriter.writeAndClose( + path.toString, + data, + ParquetWriter.Options(writeMode = ParquetFileWriter.Mode.OVERWRITE) + ) + } + + def toContainerParquetUri(path: Path): String = { + require(path.startsWith(parquetHostRoot), s"Unexpected parquet file location: $path") + val relative = parquetHostRoot.relativize(path) + Paths.get("/app/parquet").resolve(relative).toUri.toString + } + + def listDataFiles(root: Path): Set[Path] = + Using.resource(Files.walk(root)) { stream => + stream.iterator().asScala + .filter(path => Files.isRegularFile(path)) + .filter(_.getFileName.toString.endsWith(".parquet")) + .toSet + } + + def findLatestSavepoint(directory: Path): Option[Path] = + if (!Files.exists(directory)) None + else + 
Using.resource(Files.list(directory)) { stream => + stream.iterator().asScala + .filter(path => Files.isRegularFile(path)) + .filter(_.getFileName.toString.startsWith("savepoint_")) + .toSeq + }.sortBy(path => Files.getLastModifiedTime(path).toMillis) + .lastOption + + private def ensureEmptyDirectory(directory: Path): Unit = { + if (Files.exists(directory)) { + Using.resource(Files.list(directory)) { stream => + stream.iterator().asScala.foreach(deleteRecursively) + } + } + Files.createDirectories(directory) + } + + private def deleteRecursively(path: Path): Unit = { + if (Files.isDirectory(path)) { + Using.resource(Files.list(path)) { stream => + stream.iterator().asScala.foreach(deleteRecursively) + } + } + Files.deleteIfExists(path) + } + +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetMultiPartitionTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetMultiPartitionTest.scala new file mode 100644 index 00000000..818e0681 --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetMultiPartitionTest.scala @@ -0,0 +1,210 @@ +package com.scylladb.migrator.scylla + +import com.datastax.oss.driver.api.querybuilder.QueryBuilder +import com.scylladb.migrator.config.MigratorConfig + +import java.nio.file.Files +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ +import scala.sys.process.Process +import scala.util.chaining._ + +class ParquetMultiPartitionTest extends ParquetMigratorSuite { + + private val configFileName: String = "parquet-to-scylla-multipartition.yaml" + private val configFileName2: String = "parquet-to-scylla-multipartition2.yaml" + + override val munitTimeout: FiniteDuration = 2.minutes + + /** + * Run migration with custom Spark configuration to force file splitting. + * This sets spark.sql.files.maxPartitionBytes to a very small value (64KB) + * to ensure even small files get split into multiple partitions. 
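+   * spark.sql.files.openCostInBytes is lowered as well, so that partition planning is
+   * driven by maxPartitionBytes rather than by the per-file open cost.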
+ */ + private def performMigrationWithSmallPartitions( + configFile: String = configFileName + ): Unit = { + Process( + Seq( + "docker", + "compose", + "-f", + "../docker-compose-tests.yml", + "exec", + "spark-master", + "/spark/bin/spark-submit", + "--class", + "com.scylladb.migrator.Migrator", + "--master", + "spark://spark-master:7077", + "--conf", + "spark.driver.host=spark-master", + "--conf", + s"spark.scylla.config=/app/configurations/${configFile}", + "--conf", + "spark.sql.files.maxPartitionBytes=65536", // 64KB - forces multiple partitions + "--conf", + "spark.sql.files.openCostInBytes=4096", // 4KB - small open cost + "--executor-cores", "2", + "--executor-memory", "4G", + "/jars/scylla-migrator-assembly.jar" + ) + ).run() + .exitValue() + .ensuring(statusCode => statusCode == 0, "Spark job with small partitions failed") + () + } + + withTableAndSavepoints("multipartitiontest", "multipartition", "parquet-multipartition-test").test( + "Large parquet file split into multiple partitions is tracked correctly" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("multipartition") + Files.createDirectories(parquetDir) + + // Create a SINGLE large parquet file with many rows + // Each row has relatively large string to ensure file size > 64KB + val largeFilePath = parquetDir.resolve("large-file.parquet") + val largeData = (1 to 500).map { i => + // Each record is ~200 bytes, total ~100KB file + TestRecord( + id = f"id-$i%05d", + foo = s"data-$i-" + ("x" * 150), // Large string to increase file size + bar = i * 10 + ) + }.toList + + writeParquetTestFile(largeFilePath, largeData) + + val fileSizeBytes = Files.size(largeFilePath) + // With 500 rows and ~200 bytes each, file should be ~100KB + assert(fileSizeBytes > 65536, s"File size $fileSizeBytes should be > 64KB to force splitting") + + val expectedProcessedFiles = listDataFiles(parquetDir).map(toContainerParquetUri) + assertEquals(expectedProcessedFiles.size, 1, "Should have exactly 1 parquet file") + + // Run migration with small partition size to force splitting + performMigrationWithSmallPartitions() + + // Verify all data was migrated correctly + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val expectedRows = largeData.map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, expectedRows.size, "All 500 rows should be migrated") + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, expectedRows(id)) + } + } + + // Verify savepoint was created and contains the file + val savepointFile = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint file was not created")) + + val savepointConfig = MigratorConfig.loadFrom(savepointFile.toString) + val skipFiles = savepointConfig.skipParquetFiles + .getOrElse(fail("skipParquetFiles were not written")) + + assertEquals( + skipFiles, + expectedProcessedFiles, + "Savepoint should contain the large file after all its partitions completed" + ) + + // Verify idempotency: running again should skip the file + val rowCountBefore = targetScylla().execute(selectAllStatement).all().size() + + performMigrationWithSmallPartitions() + + val rowCountAfter = targetScylla().execute(selectAllStatement).all().size() + + assertEquals( + rowCountBefore, + rowCountAfter, + "Second run should not duplicate data 
(file should be skipped)" + ) + } + + withTableAndSavepoints("multipartitiontest2", "multipartition2", "parquet-multipartition-test2").test( + "Mix of single-partition and multi-partition files tracked correctly" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("multipartition2") + Files.createDirectories(parquetDir) + + // Create one LARGE file (will be split) and two SMALL files (won't be split) + val files = List( + parquetDir.resolve("small-1.parquet") -> List( + TestRecord("small-1", "data", 100) + ), + parquetDir.resolve("large.parquet") -> (1 to 500).map { i => + TestRecord( + id = f"large-$i%05d", + foo = s"data-$i-" + ("x" * 150), + bar = i * 10 + ) + }.toList, + parquetDir.resolve("small-2.parquet") -> List( + TestRecord("small-2", "data", 200) + ) + ) + + files.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + val largeFileSize = Files.size(files(1)._1) + assert(largeFileSize > 65536, s"Large file should be > 64KB, got $largeFileSize") + + val expectedProcessedFiles = listDataFiles(parquetDir).map(toContainerParquetUri) + assertEquals(expectedProcessedFiles.size, 3, "Should have 3 parquet files total") + + // Run migration + performMigrationWithSmallPartitions(configFileName2) + + // Verify all data migrated + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val expectedRows = files.flatMap(_._2).map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, expectedRows.size, "All rows from all files should be migrated") + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, expectedRows(id)) + } + } + + // Verify savepoint contains all 3 files + val savepointFile = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint file was not created")) + + val savepointConfig = MigratorConfig.loadFrom(savepointFile.toString) + val skipFiles = savepointConfig.skipParquetFiles + .getOrElse(fail("skipParquetFiles were not written")) + + assertEquals( + skipFiles.size, + 3, + "Savepoint should contain all 3 files (both single and multi-partition)" + ) + + assertEquals( + skipFiles, + expectedProcessedFiles, + "All processed files should be in savepoint" + ) + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetParallelModeTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetParallelModeTest.scala new file mode 100644 index 00000000..a384d9a8 --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetParallelModeTest.scala @@ -0,0 +1,149 @@ +package com.scylladb.migrator.scylla + +import com.datastax.oss.driver.api.querybuilder.QueryBuilder +import com.scylladb.migrator.SparkUtils.successfullyPerformMigration +import com.scylladb.migrator.config.MigratorConfig +import java.nio.file.Files +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ +import scala.util.chaining._ + +class ParquetParallelModeTest extends ParquetMigratorSuite { + + private val configFileName: String = "parquet-to-scylla-parallel.yaml" + private val savepointsConfigFileName: String = "parquet-to-scylla-parallel-savepoints.yaml" + + override val munitTimeout: FiniteDuration = 2.minutes + + FunFixture + .map2(withTable("paralleltest"), withParquetDir("parallel")) + .test("Parallel mode migration with multiple Parquet files") 
{ + case (tableName, parquetRoot) => + val parquetDir = parquetRoot.resolve("parallel") + Files.createDirectories(parquetDir) + + // Create multiple parquet files to test parallel processing + val parquetBatches = List( + parquetDir.resolve("batch-1.parquet") -> List( + TestRecord("1", "alpha", 10), + TestRecord("2", "beta", 20) + ), + parquetDir.resolve("batch-2.parquet") -> List( + TestRecord("3", "gamma", 30), + TestRecord("4", "delta", 40) + ), + parquetDir.resolve("batch-3.parquet") -> List( + TestRecord("5", "epsilon", 50) + ) + ) + + parquetBatches.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // Run migration with parallel mode + successfullyPerformMigration(configFileName) + + // Verify all data was migrated + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val expectedRows = parquetBatches.flatMap(_._2).map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, expectedRows.size) + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, expectedRows(id)) + } + } + } + + withTableAndSavepoints("paralleltest2", "parallel-savepoints", "parallel-savepoints-test").test( + "Parallel mode tracks all files in savepoints correctly" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("parallel-savepoints") + Files.createDirectories(parquetDir) + + // Create multiple parquet files to test parallel processing with savepoints + val parquetBatches = List( + parquetDir.resolve("file-1.parquet") -> List( + TestRecord("1", "alpha", 10), + TestRecord("2", "beta", 20) + ), + parquetDir.resolve("file-2.parquet") -> List( + TestRecord("3", "gamma", 30), + TestRecord("4", "delta", 40) + ), + parquetDir.resolve("file-3.parquet") -> List( + TestRecord("5", "epsilon", 50), + TestRecord("6", "zeta", 60) + ), + parquetDir.resolve("file-4.parquet") -> List( + TestRecord("7", "eta", 70) + ) + ) + + parquetBatches.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + val expectedProcessedFiles = listDataFiles(parquetDir).map(toContainerParquetUri) + + // Run migration with parallel mode and savepoints + successfullyPerformMigration(savepointsConfigFileName) + + // Verify all data was migrated correctly + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val expectedRows = parquetBatches.flatMap(_._2).map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, expectedRows.size, "All rows should be migrated") + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, expectedRows(id)) + } + } + + // Verify savepoint was created + val savepointFile = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint file was not created in parallel mode")) + + // Verify savepoint contains all processed files + val savepointConfig = MigratorConfig.loadFrom(savepointFile.toString) + val skipFiles = savepointConfig.skipParquetFiles + .getOrElse(fail("skipParquetFiles were not written in parallel mode")) + + assertEquals( + skipFiles, + expectedProcessedFiles, + "Savepoint should contain all processed files in parallel mode" + ) + + 
assertEquals(skipFiles.size, 4, "Should have tracked 4 parquet files") + + // Verify idempotency: run migration again, should skip all files + val rowCountBefore = targetScylla().execute(selectAllStatement).all().size() + + successfullyPerformMigration(savepointsConfigFileName) + + val rowCountAfter = targetScylla().execute(selectAllStatement).all().size() + + assertEquals( + rowCountBefore, + rowCountAfter, + "Second run should not insert duplicate data (idempotency check)" + ) + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetResumeIntegrationTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetResumeIntegrationTest.scala new file mode 100644 index 00000000..37dc78f5 --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetResumeIntegrationTest.scala @@ -0,0 +1,181 @@ +package com.scylladb.migrator.scylla + +import com.datastax.oss.driver.api.querybuilder.QueryBuilder +import com.scylladb.migrator.SparkUtils.successfullyPerformMigration +import com.scylladb.migrator.config.MigratorConfig + +import java.nio.file.Files +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ +import scala.util.chaining._ + +class ParquetResumeIntegrationTest extends ParquetMigratorSuite { + + override val munitTimeout: Duration = 120.seconds + + private val resumeConfig: String = "parquet-to-scylla-resume.yaml" + private val resumeAllProcessedConfig: String = "parquet-to-scylla-resume2.yaml" + + withTableAndSavepoints("resumetest", "resume", "parquet-resume-test").test( + "Resume migration after interruption skips already processed files" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("resume") + Files.createDirectories(parquetDir) + + // Phase 1: Create initial batch of files and migrate them + val firstBatch = List( + parquetDir.resolve("batch-1.parquet") -> List( + TestRecord("1", "alpha", 10), + TestRecord("2", "beta", 20) + ), + parquetDir.resolve("batch-2.parquet") -> List( + TestRecord("3", "gamma", 30), + TestRecord("4", "delta", 40) + ), + parquetDir.resolve("batch-3.parquet") -> List( + TestRecord("5", "epsilon", 50) + ) + ) + + firstBatch.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // Run first migration + successfullyPerformMigration(resumeConfig) + + // Verify first batch was migrated + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val firstBatchRows = firstBatch.flatMap(_._2).map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, firstBatchRows.size, "First batch should be fully migrated") + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, firstBatchRows(id)) + } + } + + // Load savepoint and verify it contains all first batch files + val savepointAfterFirstRun = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint file was not created after first run")) + + val configAfterFirstRun = MigratorConfig.loadFrom(savepointAfterFirstRun.toString) + val skipFilesAfterFirstRun = configAfterFirstRun.skipParquetFiles + .getOrElse(fail("skipParquetFiles were not written")) + + val expectedProcessedFiles = listDataFiles(parquetDir).map(toContainerParquetUri) + assertEquals( + skipFilesAfterFirstRun, + expectedProcessedFiles, + "Savepoint should contain all processed files from 
first batch" + ) + + // Phase 2: Create second batch of files (simulating new data) + val secondBatch = List( + parquetDir.resolve("batch-4.parquet") -> List( + TestRecord("6", "zeta", 60), + TestRecord("7", "eta", 70) + ), + parquetDir.resolve("batch-5.parquet") -> List( + TestRecord("8", "theta", 80) + ) + ) + + secondBatch.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // Phase 3: Run migration again (should skip first batch, process only second batch) + successfullyPerformMigration(resumeConfig) + + // Verify all data is present (first + second batch) + val allExpectedRows = (firstBatch ++ secondBatch) + .flatMap(_._2) + .map(row => row.id -> row) + .toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, allExpectedRows.size, "All data from both batches should be present") + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, allExpectedRows(id)) + } + } + + // Verify final savepoint includes all files + val savepointAfterSecondRun = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint file was not created after second run")) + + val configAfterSecondRun = MigratorConfig.loadFrom(savepointAfterSecondRun.toString) + val skipFilesAfterSecondRun = configAfterSecondRun.skipParquetFiles + .getOrElse(fail("skipParquetFiles were not written")) + + val allProcessedFiles = listDataFiles(parquetDir).map(toContainerParquetUri) + assertEquals( + skipFilesAfterSecondRun, + allProcessedFiles, + "Final savepoint should contain all files from both batches" + ) + + // Verify count matches expectation + assertEquals(allProcessedFiles.size, 5, "Should have 5 total parquet files") + } + + withTableAndSavepoints("resumetest2", "resume2", "parquet-resume-test2").test( + "Resume when all files already processed performs no work" + ) { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("resume2") + Files.createDirectories(parquetDir) + + val parquetBatches = List( + parquetDir.resolve("file-1.parquet") -> List( + TestRecord("1", "test1", 100) + ), + parquetDir.resolve("file-2.parquet") -> List( + TestRecord("2", "test2", 200) + ) + ) + + parquetBatches.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + // First run: migrate all files + successfullyPerformMigration(resumeAllProcessedConfig) + + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val initialRowCount = targetScylla().execute(selectAllStatement).all().size() + assertEquals(initialRowCount, 2, "Initial migration should have 2 rows") + + // Second run: should detect all files already processed and do nothing + successfullyPerformMigration(resumeAllProcessedConfig) + + // Verify data unchanged (no duplicates) + val finalRowCount = targetScylla().execute(selectAllStatement).all().size() + assertEquals(finalRowCount, 2, "Should still have exactly 2 rows (no duplicates)") + + // Verify savepoint still tracks all files + val finalSavepoint = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint should exist")) + + val finalConfig = MigratorConfig.loadFrom(finalSavepoint.toString) + val finalSkipFiles = finalConfig.skipParquetFiles.getOrElse(Set.empty) + + assertEquals(finalSkipFiles.size, 2, "Should still skip 2 files") + } +} diff --git 
a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetSavepointsIntegrationTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetSavepointsIntegrationTest.scala new file mode 100644 index 00000000..83c32ff0 --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetSavepointsIntegrationTest.scala @@ -0,0 +1,65 @@ +package com.scylladb.migrator.scylla + +import com.datastax.oss.driver.api.querybuilder.QueryBuilder +import com.scylladb.migrator.SparkUtils.successfullyPerformMigration +import com.scylladb.migrator.config.MigratorConfig + +import java.nio.file.Files +import scala.jdk.CollectionConverters._ +import scala.util.chaining._ + +class ParquetSavepointsIntegrationTest extends ParquetMigratorSuite { + + + private val configFileName: String = "parquet-to-scylla-savepoints.yaml" + + withTableAndSavepoints("savepointstest", "savepoints", "parquet-savepoints-test").test("Parquet savepoints include all processed files") { case (tableName, (parquetRoot, savepointsDir)) => + + val parquetDir = parquetRoot.resolve("savepoints") + Files.createDirectories(parquetDir) + + val parquetBatches = List( + parquetDir.resolve("batch-1.parquet") -> List( + TestRecord("1", "alpha", 10), + TestRecord("2", "beta", 20) + ), + parquetDir.resolve("batch-2.parquet") -> List( + TestRecord("3", "gamma", 30) + ) + ) + + parquetBatches.foreach { case (path, rows) => + writeParquetTestFile(path, rows) + } + + val expectedProcessedFiles = listDataFiles(parquetDir).map(toContainerParquetUri) + + successfullyPerformMigration(configFileName) + + val selectAllStatement = QueryBuilder + .selectFrom(keyspace, tableName) + .all() + .build() + + val expectedRows = parquetBatches.flatMap(_._2).map(row => row.id -> row).toMap + + targetScylla().execute(selectAllStatement).tap { resultSet => + val rows = resultSet.all().asScala + assertEquals(rows.size, expectedRows.size) + rows.foreach { row => + val id = row.getString("id") + val migrated = TestRecord(id, row.getString("foo"), row.getInt("bar")) + assertEquals(migrated, expectedRows(id)) + } + } + + val savepointFile = findLatestSavepoint(savepointsDir) + .getOrElse(fail("Savepoint file was not created")) + + val savepointConfig = MigratorConfig.loadFrom(savepointFile.toString) + val skipFiles = savepointConfig.skipParquetFiles.getOrElse(fail("skipParquetFiles were not written")) + + assertEquals(skipFiles, expectedProcessedFiles) + } + +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetSavepointsTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetSavepointsTest.scala new file mode 100644 index 00000000..27c98f87 --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetSavepointsTest.scala @@ -0,0 +1,205 @@ +package com.scylladb.migrator.scylla + +import com.scylladb.migrator.config.{MigratorConfig, SourceSettings} +import com.scylladb.migrator.readers.{Parquet, ParquetSavepointsManager} +import org.apache.spark.sql.SparkSession + +import java.nio.file.Files + +class ParquetSavepointsTest extends munit.FunSuite { + + implicit val spark: SparkSession = SparkSession + .builder() + .appName("ParquetSavepointsTest") + .master("local[*]") + .config("spark.sql.shuffle.partitions", "1") + .getOrCreate() + + override def afterAll(): Unit = { + spark.stop() + super.afterAll() + } + + test("Parquet file listing with single file") { + val tempDir = Files.createTempDirectory("parquet-test") + val tempFile = tempDir.resolve("test.parquet") + + try { + import spark.implicits._ + val testData 
= Seq(("1", "test")).toDF("id", "name") + testData.write.parquet(tempFile.toString) + + val files = Parquet.listParquetFiles(spark, tempFile.toString) + assert(files.size >= 1) + assert(files.head.contains("part-") && files.head.endsWith(".parquet")) + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("Parquet file listing with directory") { + val tempDir = Files.createTempDirectory("parquet-dir-test") + + try { + import spark.implicits._ + val testData1 = Seq(("1", "test1")).toDF("id", "name") + val testData2 = Seq(("2", "test2")).toDF("id", "name") + + val file1Path = tempDir.resolve("file1.parquet") + val file2Path = tempDir.resolve("file2.parquet") + + testData1.write.parquet(file1Path.toString) + testData2.write.parquet(file2Path.toString) + + val files = Parquet.listParquetFiles(spark, tempDir.toString) + assert(files.size >= 2) + files.foreach(file => { + assert(file.contains("part-") && file.endsWith(".parquet")) + }) + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("ParquetSavepointsManager initialization and state") { + val tempDir = Files.createTempDirectory("savepoints-test") + + try { + val config = MigratorConfig( + source = SourceSettings.Parquet("dummy", None, None, None), + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, tempDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = Some(Set("file1.parquet")), + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + val initialState = manager.describeMigrationState() + assert(initialState.contains("Processed files: 1")) + + val updatedConfig = manager.updateConfigWithMigrationState() + assertEquals(updatedConfig.skipParquetFiles.get, Set("file1.parquet")) + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("Parquet file filtering with skipFiles") { + val tempDir = Files.createTempDirectory("parquet-skip-test") + + try { + import spark.implicits._ + + val testData1 = Seq(("1", "data1")).toDF("id", "name") + val testData2 = Seq(("2", "data2")).toDF("id", "name") + + val file1Path = tempDir.resolve("file1.parquet") + val file2Path = tempDir.resolve("file2.parquet") + + testData1.write.parquet(file1Path.toString) + testData2.write.parquet(file2Path.toString) + + // List all files + val allFiles = Parquet.listParquetFiles(spark, tempDir.toString) + assert(allFiles.length >= 2) + + // Test filtering with skipFiles + val fileToSkip = allFiles.head + val skipFiles = Set(fileToSkip) + val filesToProcess = allFiles.filterNot(skipFiles.contains) + + assertEquals(filesToProcess.size, allFiles.size - 1) + assert(!filesToProcess.contains(fileToSkip)) + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } + + test("ParquetSavepointsManager tracks files during per-file processing") { + val tempDir = Files.createTempDirectory("savepoints-tracking-test") + val savepointsDir = Files.createTempDirectory("savepoints-output") + + try { + import spark.implicits._ + + val testData1 = Seq((1, "data1")).toDF("id", "name") + val testData2 = Seq((2, "data2")).toDF("id", "name") + val testData3 = Seq((3, "data3")).toDF("id", "name") + + val file1Path = tempDir.resolve("file1.parquet") + val file2Path = 
tempDir.resolve("file2.parquet") + val file3Path = tempDir.resolve("file3.parquet") + + testData1.write.parquet(file1Path.toString) + testData2.write.parquet(file2Path.toString) + testData3.write.parquet(file3Path.toString) + + val parquetSource = SourceSettings.Parquet(tempDir.toString, None, None, None) + + val allFiles = Parquet.listParquetFiles(spark, tempDir.toString) + + val config = MigratorConfig( + source = parquetSource, + target = null, + renames = None, + savepoints = com.scylladb.migrator.config.Savepoints(300, savepointsDir.toString), + skipTokenRanges = None, + skipSegments = None, + skipParquetFiles = None, + validation = None + ) + + val manager = ParquetSavepointsManager(config, spark.sparkContext) + + try { + assertEquals(manager.updateConfigWithMigrationState().skipParquetFiles.getOrElse(Set.empty).size, 0) + + allFiles.zipWithIndex.foreach { case (filePath, index) => + val singleFileDF = spark.read.parquet(filePath) + val count = singleFileDF.count() + assert(count >= 1) + + manager.markFileAsProcessed(filePath) + val processedSoFar = manager.updateConfigWithMigrationState().skipParquetFiles.getOrElse(Set.empty) + assertEquals(processedSoFar.size, index + 1, s"After processing file ${index + 1}") + } + + val finalProcessed = manager.updateConfigWithMigrationState().skipParquetFiles.getOrElse(Set.empty) + assertEquals(finalProcessed.size, allFiles.size) + assert(finalProcessed == allFiles.toSet) + + } finally { + manager.close() + } + + } finally { + Files.walk(tempDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + Files.walk(savepointsDir) + .sorted(java.util.Comparator.reverseOrder()) + .forEach(Files.delete) + } + } +} diff --git a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetToScyllaBasicMigrationTest.scala b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetToScyllaBasicMigrationTest.scala index 8e740511..c5105cee 100644 --- a/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetToScyllaBasicMigrationTest.scala +++ b/tests/src/test/scala/com/scylladb/migrator/scylla/ParquetToScyllaBasicMigrationTest.scala @@ -1,79 +1,33 @@ package com.scylladb.migrator.scylla -import com.datastax.oss.driver.api.core.CqlSession -import com.datastax.oss.driver.api.querybuilder.{ QueryBuilder, SchemaBuilder } -import com.github.mjakubowski84.parquet4s.ParquetWriter -import com.scylladb.migrator.CassandraUtils.dropAndRecreateTable +import com.datastax.oss.driver.api.querybuilder.QueryBuilder import com.scylladb.migrator.SparkUtils.successfullyPerformMigration -import com.scylladb.migrator.scylla.ParquetToScyllaBasicMigrationTest.BasicTestSchema -import org.apache.parquet.hadoop.ParquetFileWriter -import java.net.InetSocketAddress +import java.nio.file.{Files, Path} import scala.jdk.CollectionConverters._ import scala.util.chaining._ -class ParquetToScyllaBasicMigrationTest extends munit.FunSuite { +class ParquetToScyllaBasicMigrationTest extends ParquetMigratorSuite { - test("Basic migration from Parquet to ScyllaDB") { - val keyspace = "test" - val tableName = "BasicTest" - - val targetScylla: CqlSession = CqlSession - .builder() - .addContactPoint(new InetSocketAddress("localhost", 9042)) - .withLocalDatacenter("datacenter1") - .withAuthCredentials("dummy", "dummy") - .build() - - val keyspaceStatement = - SchemaBuilder - .createKeyspace(keyspace) - .ifNotExists() - .withReplicationOptions(Map[String, AnyRef]( - "class" -> "SimpleStrategy", - "replication_factor" -> new Integer(1)).asJava) - .build() - 
targetScylla.execute(keyspaceStatement) - - // Create the Parquet data source - ParquetWriter.writeAndClose( - "docker/parquet/basic.parquet", - List(BasicTestSchema(id = "12345", foo = "bar")), - ParquetWriter.Options(writeMode = ParquetFileWriter.Mode.OVERWRITE) + FunFixture.map2(withTable("BasicTest"), withParquetDir("root")).test("Basic migration from Parquet to ScyllaDB") { case (tableName, parquetRoot) => + writeParquetTestFile( + parquetRoot.resolve("basic.parquet"), + List(TestRecord(id = "12345", foo = "bar", bar = 0)) ) - // Create the target table in the target database - dropAndRecreateTable(targetScylla, keyspace, tableName, identity) - - // Perform the migration successfullyPerformMigration("parquet-to-scylla-basic.yaml") - // Check that the item has been migrated to the target table val selectAllStatement = QueryBuilder .selectFrom(keyspace, tableName) .all() .build() - targetScylla.execute(selectAllStatement).tap { resultSet => + targetScylla().execute(selectAllStatement).tap { resultSet => val rows = resultSet.all().asScala assertEquals(rows.size, 1) val row = rows.head assertEquals(row.getString("id"), "12345") assertEquals(row.getString("foo"), "bar") } - - // Clean the target table - val dropTableQuery = SchemaBuilder.dropTable(keyspace, tableName).build() - targetScylla.execute(dropTableQuery) - // Close the database driver - targetScylla.close() } } - -object ParquetToScyllaBasicMigrationTest { - - // parquet4s automatically derives the Parquet schema from this class definition. - // It must be consistent with the definition of the table from `dropAndRecreateTable`. - case class BasicTestSchema(id: String, foo: String) - -}
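
The per-file completion tracking exercised by FileCompletionListenerTest and the multi-partition tests boils down to the pattern sketched below as a plain SparkListener. This is an illustrative sketch only, not the FileCompletionListener introduced by this diff: it assumes that a finished task's taskInfo.index identifies the partition it processed, and the onFileComplete callback stands in for a savepoints hook such as ParquetSavepointsManager.markFileAsProcessed.

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import scala.collection.mutable

/** Illustrative sketch, not the class added by this diff.
  * A file counts as "complete" once every partition that reads from it has finished,
  * at which point the callback is invoked exactly once for that file.
  */
final class FileTrackingSketch(
  partitionToFiles: Map[Int, Set[String]],
  fileToPartitions: Map[String, Set[Int]],
  onFileComplete: String => Unit
) extends SparkListener {

  private val finishedPartitions = mutable.Set.empty[Int]
  private val completedFiles = mutable.Set.empty[String]

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
    // Assumption: the task's index within its stage identifies the partition it processed.
    val partitionId = taskEnd.taskInfo.index
    finishedPartitions += partitionId
    partitionToFiles.getOrElse(partitionId, Set.empty).foreach { file =>
      val allPartitionsDone =
        fileToPartitions.getOrElse(file, Set.empty).forall(finishedPartitions.contains)
      // Only notify once per file, and only after its last partition has ended.
      if (allPartitionsDone && completedFiles.add(file)) onFileComplete(file)
    }
  }

  def completedFilesCount: Int = synchronized(completedFiles.size)

  def progressReport: String =
    synchronized(s"$completedFilesCount/${fileToPartitions.size} files completed")
}

Wiring it into a job would amount to registering it with spark.sparkContext.addSparkListener, passing in the two maps built from PartitionMetadataReader's output and a callback that records the file in the savepoint, which mirrors what the tests above assert about progress reports and savepoint contents.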