Commit 076ccad

Replace Parquet strategy pattern with unified parallel approach

Authored and committed by Artem
1 parent: 9712e8c
22 files changed: +1387 -319 lines

migrator/src/main/scala/com/scylladb/migrator/config/Savepoints.scala

Lines changed: 1 addition & 30 deletions
@@ -3,36 +3,7 @@ package com.scylladb.migrator.config
 import io.circe.{ Decoder, Encoder }
 import io.circe.generic.semiauto.{ deriveDecoder, deriveEncoder }
 
-sealed trait ParquetProcessingMode
-object ParquetProcessingMode {
-  case object Parallel extends ParquetProcessingMode
-  case object Sequential extends ParquetProcessingMode
-
-  implicit val encoder: Encoder[ParquetProcessingMode] = Encoder.encodeString.contramap {
-    case Parallel   => "parallel"
-    case Sequential => "sequential"
-  }
-
-  implicit val decoder: Decoder[ParquetProcessingMode] = Decoder.decodeString.emap {
-    case "parallel"   => Right(Parallel)
-    case "sequential" => Right(Sequential)
-    case other =>
-      Left(s"Unknown parquet processing mode: $other. Valid values: parallel, sequential")
-  }
-}
-
-case class Savepoints(intervalSeconds: Int,
-                      path: String,
-                      parquetProcessingMode: Option[ParquetProcessingMode]) {
-
-  /**
-    * Returns the configured Parquet processing mode.
-    * Defaults to [[ParquetProcessingMode.Parallel]] if not specified.
-    * This default affects migration semantics.
-    */
-  def getParquetProcessingMode: ParquetProcessingMode =
-    parquetProcessingMode.getOrElse(ParquetProcessingMode.Parallel)
-}
+case class Savepoints(intervalSeconds: Int, path: String)
 
 object Savepoints {
   implicit val encoder: Encoder[Savepoints] = deriveEncoder[Savepoints]
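
For context (not part of this commit), here is a minimal sketch of what decoding the simplified Savepoints looks like, assuming circe-parser is on the classpath and using the derived codecs from the companion object above; the JSON literal and the object name are illustrative only:

import com.scylladb.migrator.config.Savepoints
import io.circe.parser.decode

object SavepointsDecodingSketch extends App {
  // Only the checkpoint interval and the savepoints path remain configurable;
  // the removed parquetProcessingMode field is simply no longer part of the model.
  val json = """{ "intervalSeconds": 300, "path": "/var/lib/migrator/savepoints" }"""

  decode[Savepoints](json) match {
    case Right(savepoints) =>
      println(s"intervalSeconds=${savepoints.intervalSeconds}, path=${savepoints.path}")
    case Left(error) =>
      println(s"Failed to decode savepoints config: $error")
  }
}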
migrator/src/main/scala/com/scylladb/migrator/readers/FileCompletionListener.scala

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+package com.scylladb.migrator.readers
+
+import org.apache.log4j.LogManager
+import org.apache.spark.scheduler.{ SparkListener, SparkListenerTaskEnd }
+import org.apache.spark.Success
+
+import scala.collection.concurrent.TrieMap
+
+/**
+  * SparkListener that tracks partition completion and aggregates it to file-level completion.
+  *
+  * This listener monitors Spark task completion events and maintains mappings between
+  * partitions and files. When all partitions belonging to a file have been successfully
+  * completed, it marks the file as processed via the ParquetSavepointsManager.
+  *
+  * @param partitionToFiles Mapping from Spark partition ID to source file paths
+  * @param fileToPartitions Mapping from file path to the set of partition IDs reading from it
+  * @param savepointsManager Manager to notify when files are completed
+  */
+class FileCompletionListener(
+  partitionToFiles: Map[Int, Set[String]],
+  fileToPartitions: Map[String, Set[Int]],
+  savepointsManager: ParquetSavepointsManager
+) extends SparkListener {
+
+  private val log = LogManager.getLogger("com.scylladb.migrator.readers.FileCompletionListener")
+
+  private val completedPartitions = TrieMap.empty[Int, Boolean]
+
+  private val completedFiles = TrieMap.empty[String, Boolean]
+
+  log.info(
+    s"FileCompletionListener initialized: tracking ${fileToPartitions.size} files " +
+      s"across ${partitionToFiles.size} partitions")
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit =
+    if (taskEnd.reason == Success) {
+      val partitionId = taskEnd.taskInfo.partitionId
+
+      partitionToFiles.get(partitionId) match {
+        case Some(filenames) =>
+          if (completedPartitions.putIfAbsent(partitionId, true).isEmpty) {
+            filenames.foreach { filename =>
+              log.debug(s"Partition $partitionId completed (file: $filename)")
+              checkFileCompletion(filename)
+            }
+          }
+
+        case None =>
+          log.trace(s"Task completed for untracked partition $partitionId")
+      }
+    } else {
+      log.debug(
+        s"Task for partition ${taskEnd.taskInfo.partitionId} did not complete successfully: ${taskEnd.reason}")
+    }
+
+  private def checkFileCompletion(filename: String): Unit = {
+    if (completedFiles.contains(filename)) {
+      return
+    }
+
+    fileToPartitions.get(filename) match {
+      case Some(allPartitions) =>
+        val allComplete = allPartitions.forall(completedPartitions.contains)
+
+        if (allComplete) {
+          if (completedFiles.putIfAbsent(filename, true).isEmpty) {
+            savepointsManager.markFileAsProcessed(filename)
+
+            val progress = s"${completedFiles.size}/${fileToPartitions.size}"
+            log.info(s"File completed: $filename (progress: $progress)")
+          }
+        } else {
+          val completedCount = allPartitions.count(completedPartitions.contains)
+          log.trace(s"File $filename: $completedCount/${allPartitions.size} partitions complete")
+        }
+
+      case None =>
+        log.warn(s"File $filename not found in fileToPartitions map (this shouldn't happen)")
+    }
+  }
+
+  def getCompletedFilesCount: Int = completedFiles.size
+
+  def getTotalFilesCount: Int = fileToPartitions.size
+
+  def getProgressReport: String = {
+    val filesCompleted = getCompletedFilesCount
+    val totalFiles = getTotalFilesCount
+
+    s"Progress: $filesCompleted/$totalFiles files"
+  }
+}
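
The heart of the listener is the idempotent aggregation from partition completions to file completions: TrieMap.putIfAbsent returns None only on the first insertion, so retried or speculative tasks are counted once and each file is reported exactly once. A standalone sketch of that pattern (illustrative only; the file names are hypothetical and markFileAsProcessed below is a stand-in for the real ParquetSavepointsManager):

import scala.collection.concurrent.TrieMap

object CompletionAggregationSketch extends App {
  // Hypothetical layout: a.parquet is split across partitions 0 and 1, b.parquet maps to partition 2.
  val fileToPartitions: Map[String, Set[Int]] = Map(
    "s3a://bucket/a.parquet" -> Set(0, 1),
    "s3a://bucket/b.parquet" -> Set(2)
  )
  val partitionToFiles: Map[Int, Set[String]] = Map(
    0 -> Set("s3a://bucket/a.parquet"),
    1 -> Set("s3a://bucket/a.parquet"),
    2 -> Set("s3a://bucket/b.parquet")
  )

  private val completedPartitions = TrieMap.empty[Int, Boolean]
  private val completedFiles = TrieMap.empty[String, Boolean]

  // Stand-in for ParquetSavepointsManager.markFileAsProcessed.
  private def markFileAsProcessed(file: String): Unit =
    println(s"savepoint: $file fully processed")

  def onPartitionSuccess(partitionId: Int): Unit =
    // The first successful completion of this partition wins; duplicates are ignored.
    if (completedPartitions.putIfAbsent(partitionId, true).isEmpty) {
      partitionToFiles.getOrElse(partitionId, Set.empty).foreach { file =>
        val partitions = fileToPartitions.getOrElse(file, Set.empty)
        val fileDone = partitions.forall(completedPartitions.contains)
        if (fileDone && completedFiles.putIfAbsent(file, true).isEmpty)
          markFileAsProcessed(file)
      }
    }

  // Partition 0 completes twice (e.g. a speculative duplicate); each file is still reported once.
  Seq(0, 2, 0, 1).foreach(onPartitionSuccess)
}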

migrator/src/main/scala/com/scylladb/migrator/readers/ParallelParquetStrategy.scala

Lines changed: 0 additions & 40 deletions
This file was deleted.

migrator/src/main/scala/com/scylladb/migrator/readers/Parquet.scala

Lines changed: 52 additions & 40 deletions
@@ -1,61 +1,73 @@
 package com.scylladb.migrator.readers
 
-import com.scylladb.migrator.config.{
-  MigratorConfig,
-  ParquetProcessingMode,
-  SourceSettings,
-  TargetSettings
-}
-import com.scylladb.migrator.scylla.SourceDataFrame
-import com.scylladb.migrator.scylla
+import com.scylladb.migrator.config.{ MigratorConfig, SourceSettings, TargetSettings }
+import com.scylladb.migrator.scylla.{ ScyllaParquetMigrator, SourceDataFrame }
 import org.apache.log4j.LogManager
 import org.apache.spark.sql.{ AnalysisException, SparkSession }
 import scala.util.Using
 
-case class ParquetReaderWithSavepoints(source: SourceSettings.Parquet,
-                                       allFiles: Seq[String],
-                                       skipFiles: Set[String]) {
-
-  val filesToProcess: Seq[String] = allFiles.filterNot(skipFiles.contains)
-
-  def configureHadoop(spark: SparkSession): Unit =
-    Parquet.configureHadoopCredentials(spark, source)
-}
-
 object Parquet {
   val log = LogManager.getLogger("com.scylladb.migrator.readers.Parquet")
 
   def migrateToScylla(config: MigratorConfig,
                       source: SourceSettings.Parquet,
                       target: TargetSettings.Scylla)(implicit spark: SparkSession): Unit = {
-    val processingMode = config.savepoints.getParquetProcessingMode
-
-    val strategy: ParquetProcessingStrategy = processingMode match {
-      case ParquetProcessingMode.Parallel =>
-        log.info("Selected PARALLEL processing mode (default)")
-        new ParallelParquetStrategy()
-      case ParquetProcessingMode.Sequential =>
-        log.info("Selected SEQUENTIAL processing mode (with savepoints)")
-        new SequentialParquetStrategy()
-    }
-
-    strategy.migrate(config, source, target)
-  }
-
-  def prepareParquetReader(spark: SparkSession,
-                           source: SourceSettings.Parquet,
-                           skipFiles: Set[String] = Set.empty): ParquetReaderWithSavepoints = {
+    log.info("Starting Parquet migration with parallel processing and file-level savepoints")
 
     configureHadoopCredentials(spark, source)
 
     val allFiles = listParquetFiles(spark, source.path)
-    log.info(s"Found ${allFiles.size} Parquet files in ${source.path}")
+    val skipFiles = config.getSkipParquetFilesOrEmptySet
+    val filesToProcess = allFiles.filterNot(skipFiles.contains)
+
+    if (filesToProcess.isEmpty) {
+      log.info("No Parquet files to process. Migration is complete.")
+      return
+    }
+
+    log.info(s"Processing ${filesToProcess.size} Parquet files")
 
-    if (skipFiles.nonEmpty) {
-      log.info(s"Skipping ${skipFiles.size} already processed files")
+    val df = if (skipFiles.isEmpty) {
+      spark.read.parquet(source.path)
+    } else {
+      spark.read.parquet(filesToProcess: _*)
     }
 
-    ParquetReaderWithSavepoints(source, allFiles, skipFiles)
+    log.info("Reading partition metadata for file tracking...")
+    val metadata = PartitionMetadataReader.readMetadataFromDataFrame(df)
+
+    val partitionToFiles = PartitionMetadataReader.buildPartitionToFileMap(metadata)
+    val fileToPartitions = PartitionMetadataReader.buildFileToPartitionsMap(metadata)
+
+    log.info(
+      s"Discovered ${fileToPartitions.size} files with ${metadata.size} total partitions to process")
+
+    Using.resource(ParquetSavepointsManager(config, spark.sparkContext)) { savepointsManager =>
+      val listener = new FileCompletionListener(
+        partitionToFiles,
+        fileToPartitions,
+        savepointsManager
+      )
+      spark.sparkContext.addSparkListener(listener)
+
+      try {
+        val sourceDF = SourceDataFrame(df, None, savepointsSupported = false)
+
+        log.info("Created DataFrame from Parquet source")
+
+        ScyllaParquetMigrator.migrate(config, target, sourceDF, savepointsManager)
+
+        savepointsManager.dumpMigrationState("completed")
+
+        log.info(
+          s"Parquet migration completed successfully: " +
+            s"${listener.getCompletedFilesCount}/${listener.getTotalFilesCount} files processed")
+
+      } finally {
+        spark.sparkContext.removeSparkListener(listener)
+        log.info(s"Final progress: ${listener.getProgressReport}")
+      }
+    }
   }
 
   def listParquetFiles(spark: SparkSession, path: String): Seq[String] = {
@@ -88,7 +100,7 @@ object Parquet {
    * This method sets the necessary Hadoop configuration properties for AWS access key, secret key,
    * and optionally a session token. When a session token is present, it sets the credentials provider
    * to TemporaryAWSCredentialsProvider as required by Hadoop.
-   * 
+   *
    * If a region is specified in the source configuration, this method also sets the S3A endpoint region
    * via the `fs.s3a.endpoint.region` property.
    *
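
The hunk above ends at the signature of listParquetFiles without showing its body. For orientation only, here is one plausible way such a helper can be written with the Hadoop FileSystem API (an assumption, not the project's actual implementation): it enumerates the .parquet files directly under the source path so that S3A, HDFS, and local paths are handled uniformly.

import java.net.URI
import org.apache.hadoop.fs.{ FileSystem, Path }
import org.apache.spark.sql.SparkSession

object ListParquetFilesSketch {
  // Resolve the FileSystem from the path's scheme (s3a://, hdfs://, file://, ...)
  // and return the fully qualified paths of the Parquet files directly under it.
  def listParquetFiles(spark: SparkSession, path: String): Seq[String] = {
    val fs = FileSystem.get(new URI(path), spark.sparkContext.hadoopConfiguration)
    fs.listStatus(new Path(path))
      .filter(status => status.isFile && status.getPath.getName.endsWith(".parquet"))
      .map(_.getPath.toString)
      .toSeq
  }
}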

migrator/src/main/scala/com/scylladb/migrator/readers/ParquetProcessingStrategy.scala

Lines changed: 0 additions & 10 deletions
This file was deleted.

migrator/src/main/scala/com/scylladb/migrator/readers/PartitionMetadataReader.scala

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+package com.scylladb.migrator.readers
+
+import org.apache.log4j.LogManager
+import org.apache.spark.sql.{ DataFrame, SparkSession }
+import org.apache.spark.sql.functions._
+
+case class PartitionMetadata(
+  partitionId: Int,
+  filename: String
+)
+
+/**
+  * This reader uses Spark's internal partition information to build mappings
+  * between partition IDs and file paths. This allows us to track when all
+  * partitions of a file have been processed, enabling file-level savepoints.
+  */
+object PartitionMetadataReader {
+  private val logger = LogManager.getLogger("com.scylladb.migrator.readers.PartitionMetadataReader")
+
+  def readMetadata(spark: SparkSession, filePaths: Seq[String]): Seq[PartitionMetadata] = {
+    logger.info(s"Reading partition metadata from ${filePaths.size} file(s)")
+    val df = spark.read.parquet(filePaths: _*)
+    readMetadataFromDataFrame(df)
+  }
+
+  def readMetadataFromDataFrame(df: DataFrame): Seq[PartitionMetadata] =
+    try {
+
+      val partitionInfo = df
+        .select(input_file_name().as("filename"))
+        .rdd
+        .mapPartitionsWithIndex { (partitionId, iter) =>
+          val files = iter.map(row => row.getString(0)).toSet
+          files.map(filename => (partitionId, filename)).iterator
+        }
+        .collect()
+        .toSeq
+
+      val metadata = partitionInfo.zipWithIndex.map {
+        case ((partitionId, filename), idx) =>
+          PartitionMetadata(
+            partitionId = partitionId,
+            filename = filename
+          )
+      }
+
+      logger.info(s"Discovered ${metadata.size} partition-to-file mappings")
+
+      val fileStats = metadata.groupBy(_.filename).view.mapValues(_.size)
+      logger.info(s"Files distribution: ${fileStats.size} unique files")
+      fileStats.foreach {
+        case (file, partCount) =>
+          logger.debug(s"  File: $file -> $partCount partition(s)")
+      }
+
+      metadata
+
+    } catch {
+      case e: Exception =>
+        logger.error(s"Failed to read partition metadata", e)
+        throw new RuntimeException(s"Could not read partition metadata: ${e.getMessage}", e)
+    }
+
+  def buildFileToPartitionsMap(metadata: Seq[PartitionMetadata]): Map[String, Set[Int]] = {
+    val result = metadata
+      .groupBy(_.filename)
+      .view
+      .mapValues(_.map(_.partitionId).toSet)
+      .toMap
+
+    logger.debug(s"Built file-to-partitions map with ${result.size} files")
+    result
+  }
+
+  def buildPartitionToFileMap(metadata: Seq[PartitionMetadata]): Map[Int, Set[String]] = {
+    val result = metadata
+      .groupBy(_.partitionId)
+      .view
+      .mapValues(_.map(_.filename).toSet)
+      .toMap
+
+    logger.debug(s"Built partition-to-file map with ${result.size} partitions")
+    result
+  }
+}
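
To make the two mappings concrete, a small usage sketch with hypothetical metadata (the file paths are invented for illustration): buildFileToPartitionsMap groups partition IDs by file, while buildPartitionToFileMap inverts the relation.

import com.scylladb.migrator.readers.{ PartitionMetadata, PartitionMetadataReader }

object PartitionMappingSketch extends App {
  // Partition 0 happens to read rows from two files; partition 1 reads only from the second.
  val metadata = Seq(
    PartitionMetadata(partitionId = 0, filename = "s3a://bucket/a.parquet"),
    PartitionMetadata(partitionId = 0, filename = "s3a://bucket/b.parquet"),
    PartitionMetadata(partitionId = 1, filename = "s3a://bucket/b.parquet")
  )

  // Map(a.parquet -> Set(0), b.parquet -> Set(0, 1))
  println(PartitionMetadataReader.buildFileToPartitionsMap(metadata))

  // Map(0 -> Set(a.parquet, b.parquet), 1 -> Set(b.parquet))
  println(PartitionMetadataReader.buildPartitionToFileMap(metadata))
}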
