diff --git a/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergFileIO.java b/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergFileIO.java
index 39c6b325133..c7c8461ddbb 100644
--- a/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergFileIO.java
+++ b/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergFileIO.java
@@ -18,6 +18,7 @@
import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO;
import com.nvidia.spark.rapids.jni.fileio.RapidsInputFile;
+import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
import org.apache.iceberg.io.FileIO;
import java.io.IOException;
@@ -48,4 +49,9 @@ public IcebergFileIO(FileIO delegate) {
public IcebergInputFile newInputFile(String path) throws IOException {
return new IcebergInputFile(delegate.newInputFile(path));
}
+
+ /**
+  * Creates a writable file handle for {@code path} by asking the wrapped Iceberg {@link FileIO}
+  * for an output file and wrapping the result in an {@link IcebergOutputFile}.
+  *
+  * @param path location of the file to write; forwarded unchanged to the delegate
+  * @return an {@link IcebergOutputFile} wrapping the delegate's output file
+  * @throws IOException declared by the overridden signature
+  */
+ @Override
+ public IcebergOutputFile newOutputFile(String path) throws IOException {
+ return new IcebergOutputFile(delegate.newOutputFile(path));
+ }
}
diff --git a/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputFile.java b/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputFile.java
new file mode 100644
index 00000000000..56dda3284e7
--- /dev/null
+++ b/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputFile.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.fileio.iceberg;
+
+import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
+import org.apache.iceberg.io.OutputFile;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * {@link RapidsOutputFile} adapter that forwards every operation to an Iceberg
+ * {@link OutputFile}.
+ */
+public class IcebergOutputFile implements RapidsOutputFile {
+  private final OutputFile delegate;
+
+  /**
+   * Wraps an Iceberg output file.
+   *
+   * @param delegate the Iceberg output file that receives all calls
+   * @throws NullPointerException if {@code delegate} is null
+   */
+  public IcebergOutputFile(OutputFile delegate) {
+    this.delegate = Objects.requireNonNull(delegate, "delegate can't be null");
+  }
+
+  /**
+   * Opens a stream on the wrapped file.
+   *
+   * @param overwrite when true the stream comes from {@link OutputFile#createOrOverwrite()},
+   *                  otherwise from {@link OutputFile#create()}
+   * @return an {@link IcebergOutputStream} wrapping the stream returned by the delegate
+   * @throws IOException declared by the overridden signature
+   */
+  @Override
+  public IcebergOutputStream create(boolean overwrite) throws IOException {
+    return new IcebergOutputStream(
+        overwrite ? delegate.createOrOverwrite() : delegate.create());
+  }
+
+  /**
+   * @return the location string reported by the wrapped Iceberg file
+   */
+  @Override
+  public String getPath() {
+    return delegate.location();
+  }
+}
diff --git a/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputStream.java b/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputStream.java
new file mode 100644
index 00000000000..ee2a34f8334
--- /dev/null
+++ b/iceberg/src/main/java/com/nvidia/spark/rapids/fileio/iceberg/IcebergOutputStream.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.fileio.iceberg;
+
+import com.nvidia.spark.rapids.jni.fileio.RapidsOutputStream;
+import org.apache.iceberg.io.PositionOutputStream;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Adapts an Iceberg {@link PositionOutputStream} to the {@link RapidsOutputStream} API.
+ *
+ * Writes and flushes are forwarded unchanged. Once {@link #close()} has completed successfully,
+ * further calls to it are ignored.
+ */
+public class IcebergOutputStream extends RapidsOutputStream {
+  private final PositionOutputStream out;
+  private boolean closed = false;
+
+  /**
+   * @param out the Iceberg stream that receives all writes
+   * @throws NullPointerException if {@code out} is null
+   */
+  public IcebergOutputStream(PositionOutputStream out) {
+    this.out = requireNonNull(out, "out can't be null");
+  }
+
+  /** Forwards a single byte to the wrapped stream. */
+  @Override
+  public void write(int b) throws IOException {
+    out.write(b);
+  }
+
+  /** Forwards {@code len} bytes of {@code b}, starting at {@code off}, to the wrapped stream. */
+  @Override
+  public void write(byte[] b, int off, int len) throws IOException {
+    out.write(b, off, len);
+  }
+
+  /** Flushes the wrapped stream. */
+  @Override
+  public void flush() throws IOException {
+    out.flush();
+  }
+
+  /**
+   * Closes the wrapped stream unless a previous call already closed it successfully.
+   *
+   * @throws IOException if closing the wrapped stream fails
+   */
+  @Override
+  public void close() throws IOException {
+    if (closed) {
+      return;
+    }
+    out.close();
+    // Set only after out.close() returned normally; if it threw, the flag stays false.
+    closed = true;
+  }
+}
diff --git a/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkFileWriterFactory.scala b/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkFileWriterFactory.scala
index 63055d6cf79..e5c5fd61351 100644
--- a/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkFileWriterFactory.scala
+++ b/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkFileWriterFactory.scala
@@ -38,6 +38,7 @@ class GpuSparkFileWriterFactory(val table: Table,
val columnarOutputWriterFactory: ColumnarOutputWriterFactory,
val taskStatsTracker: ColumnarWriteTaskStatsTracker,
val hadoopConf: SerializableConfiguration,
+ val fileIO: IcebergFileIO
) extends FileWriterFactory[SpillableColumnarBatch] {
require(dataFileFormat == FileFormat.PARQUET,
s"GpuSparkFileWriterFactory only supports PARQUET file format, but got $dataFileFormat")
@@ -78,8 +79,8 @@ class GpuSparkFileWriterFactory(val table: Table,
dataSchema = dataSparkType,
context = taskAttemptContext,
statsTrackers = Seq(taskStatsTracker),
- debugOutputPath = None
- ).asInstanceOf[GpuParquetWriter]
+ debugOutputPath = None,
+ fileIO = fileIO).asInstanceOf[GpuParquetWriter]
new GpuIcebergParquetAppender(
gpuWriter,
diff --git a/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkWrite.scala b/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkWrite.scala
index 751d79f83d3..5bee0e2f391 100644
--- a/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkWrite.scala
+++ b/iceberg/src/main/scala/org/apache/iceberg/spark/source/GpuSparkWrite.scala
@@ -22,10 +22,9 @@ import scala.util.{Failure, Success}
import com.nvidia.spark.rapids.{ColumnarOutputWriterFactory, GpuParquetFileFormat, GpuWrite, SparkPlanMeta, SpillableColumnarBatch}
import com.nvidia.spark.rapids.Arm.closeOnExcept
import com.nvidia.spark.rapids.SpillPriorities.ACTIVE_ON_DECK_PRIORITY
+import com.nvidia.spark.rapids.fileio.iceberg.IcebergFileIO
import com.nvidia.spark.rapids.iceberg.GpuIcebergPartitioner
-import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
-import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.shaded.org.apache.commons.lang3.reflect.{FieldUtils, MethodUtils}
import org.apache.iceberg.{DataFile, FileFormat, PartitionSpec, Schema, SerializableTable, SnapshotUpdate, Table}
import org.apache.iceberg.io.{DataWriteResult, FileIO, GpuClusteredDataWriter, GpuFanoutDataWriter, GpuRollingDataWriter, OutputFileFactory, PartitioningWriter}
@@ -96,7 +95,6 @@ class GpuSparkWrite(cpu: SparkWrite) extends GpuWrite with RequiresDistributionA
val tmpJob = Job.getInstance(hadoopConf)
tmpJob.setOutputKeyClass(classOf[Void])
tmpJob.setOutputValueClass(classOf[InternalRow])
- FileOutputFormat.setOutputPath(tmpJob, new Path(table.location()))
tmpJob
}
@@ -202,6 +200,8 @@ class GpuWriterFactory(val tableBroadcast: Broadcast[Table],
val hadoopConf: SerializableConfiguration
) extends DataWriterFactory {
+ private lazy val fileIO: IcebergFileIO = new IcebergFileIO(tableBroadcast.value.io())
+
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
val table = tableBroadcast.value
val spec = table.specs().get(outputSpecId)
@@ -221,7 +221,8 @@ class GpuWriterFactory(val tableBroadcast: Broadcast[Table],
format,
outputWriterFactory,
statsTracker.newTaskInstance(),
- hadoopConf)
+ hadoopConf,
+ fileIO)
if (spec.isUnpartitioned) {
new GpuUnpartitionedDataWriter(writerFactory, outputFileFactory, io, spec, targetFileSize)
diff --git a/integration_tests/src/main/python/iceberg/iceberg_append_test.py b/integration_tests/src/main/python/iceberg/iceberg_append_test.py
index 5c60fb5505e..b8f69ff9237 100644
--- a/integration_tests/src/main/python/iceberg/iceberg_append_test.py
+++ b/integration_tests/src/main/python/iceberg/iceberg_append_test.py
@@ -23,12 +23,8 @@
from marks import iceberg, ignore_order, allow_non_gpu
from spark_session import with_gpu_session, with_cpu_session, is_spark_35x
-pytestmark = [
- pytest.mark.skipif(not is_spark_35x(),
- reason="Current spark-rapids only support spark 3.5.x"),
- pytest.mark.skipif(is_iceberg_remote_catalog(),
- reason="https://github.com/NVIDIA/spark-rapids/issues/13471")
-]
+pytestmark = pytest.mark.skipif(not is_spark_35x(),
+ reason="Current spark-rapids only support spark 3.5.x")
def do_test_insert_into_table_sql(spark_tmp_table_factory,
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopFileIO.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopFileIO.java
index 870b6098b4b..da45700ba31 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopFileIO.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopFileIO.java
@@ -18,7 +18,9 @@
import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO;
import com.nvidia.spark.rapids.jni.fileio.RapidsInputFile;
+import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.util.SerializableConfiguration;
@@ -38,12 +40,18 @@ public HadoopFileIO(Configuration hadoopConf) {
}
+ /**
+  * Returns the input file for {@code path} by converting the string to a Hadoop {@link Path}
+  * and delegating to {@link #newInputFile(Path)}.
+  *
+  * @param path location of the file to read
+  * @return the {@link HadoopInputFile} produced by {@link #newInputFile(Path)}
+  * @throws IOException if the delegated call fails
+  */
@Override
- public RapidsInputFile newInputFile(String path) throws IOException {
+ public HadoopInputFile newInputFile(String path) throws IOException {
return this.newInputFile(new Path(path));
}
+ /**
+  * Returns the input file for a Hadoop {@link Path}, built by {@code HadoopInputFile.create}
+  * with this instance's Hadoop configuration.
+  *
+  * @param path Hadoop path of the file to read
+  * @return the {@link HadoopInputFile} for {@code path}
+  * @throws IOException if {@code HadoopInputFile.create} fails
+  */
@Override
- public RapidsInputFile newInputFile(Path path) throws IOException {
+ public HadoopInputFile newInputFile(Path path) throws IOException {
return HadoopInputFile.create(path, hadoopConf.value());
}
+
+  /**
+   * Creates a {@link HadoopOutputFile} for {@code path} using this instance's Hadoop
+   * configuration.
+   *
+   * @param path location of the file to write
+   * @return the output file handle for {@code path}
+   * @throws NullPointerException if {@code path} is null
+   * @throws IOException if {@code HadoopOutputFile.create} fails
+   */
+  @Override
+  public HadoopOutputFile newOutputFile(String path) throws IOException {
+    final String checkedPath = Objects.requireNonNull(path, "path can't be null");
+    final Path hadoopPath = new Path(checkedPath);
+    return HadoopOutputFile.create(hadoopPath, hadoopConf.value());
+  }
}
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopOutputFile.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopOutputFile.java
new file mode 100644
index 00000000000..3ed1146eefa
--- /dev/null
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopOutputFile.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.fileio.hadoop;
+
+import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+
+import java.io.IOException;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * {@link RapidsOutputFile} backed by a Hadoop {@link FileSystem}.
+ *
+ * Instances are obtained through {@link #create(Path, Configuration)}, which resolves the file
+ * system for the path. {@link #getPath()} returns the string form of the path as it was
+ * supplied; this class does not qualify it or make it absolute.
+ */
+public class HadoopOutputFile implements RapidsOutputFile {
+  private final Path filePath;
+  private final FileSystem fs;
+
+  /**
+   * Builds an output file handle for {@code filePath}.
+   *
+   * @param filePath Hadoop path of the file to write
+   * @param conf Hadoop configuration used to resolve the path's file system
+   * @return a new {@link HadoopOutputFile} bound to the resolved file system
+   * @throws NullPointerException if {@code filePath} or {@code conf} is null
+   * @throws IOException if the file system for {@code filePath} cannot be obtained
+   */
+  public static HadoopOutputFile create(Path filePath, Configuration conf)
+      throws IOException {
+    Objects.requireNonNull(filePath, "filePath can't be null");
+    Objects.requireNonNull(conf, "Hadoop conf can't be null");
+    return new HadoopOutputFile(filePath, filePath.getFileSystem(conf));
+  }
+
+  private HadoopOutputFile(Path filePath, FileSystem fs) {
+    this.filePath = Objects.requireNonNull(filePath, "filePath can't be null");
+    this.fs = Objects.requireNonNull(fs, "FileSystem can't be null");
+  }
+
+  /**
+   * Opens a stream on the file through {@code FileSystem.create(Path, boolean)}.
+   *
+   * @param overwrite passed straight through as the file system's overwrite flag
+   * @return a {@link HadoopOutputStream} wrapping the stream returned by the file system
+   * @throws IOException if the file system fails to create the file
+   */
+  @Override
+  public HadoopOutputStream create(boolean overwrite) throws IOException {
+    FSDataOutputStream hadoopStream = fs.create(filePath, overwrite);
+    return new HadoopOutputStream(hadoopStream);
+  }
+
+  /**
+   * @return the string form of the path this file was created with
+   */
+  @Override
+  public String getPath() {
+    return filePath.toString();
+  }
+}
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopOutputStream.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopOutputStream.java
new file mode 100644
index 00000000000..301570fba1f
--- /dev/null
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/fileio/hadoop/HadoopOutputStream.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.fileio.hadoop;
+
+import com.nvidia.spark.rapids.jni.fileio.RapidsOutputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Adapts a Hadoop {@link FSDataOutputStream} to the {@link RapidsOutputStream} API.
+ *
+ * Writes and flushes are forwarded unchanged. Once {@link #close()} has completed successfully,
+ * further calls to it are ignored.
+ */
+public class HadoopOutputStream extends RapidsOutputStream {
+  private final FSDataOutputStream out;
+  private boolean closed = false;
+
+  /**
+   * @param out the Hadoop stream that receives all writes
+   * @throws NullPointerException if {@code out} is null
+   */
+  public HadoopOutputStream(FSDataOutputStream out) {
+    this.out = requireNonNull(out, "out can't be null");
+  }
+
+  /** Forwards a single byte to the wrapped stream. */
+  @Override
+  public void write(int b) throws IOException {
+    out.write(b);
+  }
+
+  /** Forwards {@code len} bytes of {@code b}, starting at {@code off}, to the wrapped stream. */
+  @Override
+  public void write(byte[] b, int off, int len) throws IOException {
+    out.write(b, off, len);
+  }
+
+  /** Flushes the wrapped stream. */
+  @Override
+  public void flush() throws IOException {
+    out.flush();
+  }
+
+  /**
+   * Closes the wrapped stream unless a previous call already closed it successfully.
+   *
+   * @throws IOException if closing the wrapped stream fails
+   */
+  @Override
+  public void close() throws IOException {
+    if (closed) {
+      return;
+    }
+    out.close();
+    // Set only after out.close() returned normally; if it threw, the flag stays false.
+    closed = true;
+  }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
index acd6e2d61bf..7b6bc084877 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala
@@ -26,6 +26,7 @@ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit}
import com.nvidia.spark.rapids.io.async.{AsyncOutputStream, TrafficController}
+import com.nvidia.spark.rapids.jni.fileio.{RapidsFileIO, RapidsOutputFile}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
@@ -63,7 +64,8 @@ abstract class ColumnarOutputWriterFactory extends Serializable {
dataSchema: StructType,
context: TaskAttemptContext,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String]): ColumnarOutputWriter
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO): ColumnarOutputWriter
}
/**
@@ -78,7 +80,8 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
debugDumpPath: Option[String],
holdGpuBetweenBatches: Boolean = false,
- useAsyncWrite: Boolean = false) extends HostBufferConsumer with Logging {
+ useAsyncWrite: Boolean = false,
+ rapidsFileIO: RapidsFileIO) extends HostBufferConsumer with Logging {
// Length of the file written so far. This is used to track the size of the file
private var fileLength: Long = 0L
@@ -125,10 +128,8 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
private val trafficController: TrafficController = TrafficController.getWriteInstance
- private def openOutputStream(): OutputStream = {
- val hadoopPath = new Path(path)
- val fs = hadoopPath.getFileSystem(conf)
- fs.create(hadoopPath, false)
+ // Resolves this writer's output path through the RapidsFileIO passed to the constructor,
+ // rather than creating the file on a Hadoop FileSystem directly. Only the file handle is
+ // returned here; callers open the stream with RapidsOutputFile.create.
+ private def openOutputFile(): RapidsOutputFile = {
+ rapidsFileIO.newOutputFile(path())
}
// This is implemented as a method to make it easier to subclass
@@ -136,9 +137,9 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
protected def getOutputStream: OutputStream = {
if (useAsyncWrite) {
logWarning("Async output write enabled")
- AsyncOutputStream(() => openOutputStream(), trafficController, statsTrackers)
+ // The stream is created inside the function handed to AsyncOutputStream, not here.
+ // `false` is the `overwrite` flag (see HadoopOutputFile.create / IcebergOutputFile.create).
+ AsyncOutputStream(() => openOutputFile().create(false), trafficController, statsTrackers)
} else {
- openOutputStream()
+ // Synchronous path: create the stream right away, again with overwrite = false.
+ openOutputFile().create(false)
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
index e427163b085..0ca808551b8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -22,6 +22,7 @@ import ai.rapids.cudf._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
import com.nvidia.spark.rapids.jni.DateTimeRebase
+import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO
import com.nvidia.spark.rapids.shims._
import com.nvidia.spark.rapids.shims.parquet._
import org.apache.hadoop.mapreduce.{Job, OutputCommitter, TaskAttemptContext}
@@ -283,11 +284,12 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
path: String,
dataSchema: StructType,
context: TaskAttemptContext,
- statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String]): ColumnarOutputWriter = {
+ statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO): ColumnarOutputWriter = {
new GpuParquetWriter(path, dataSchema, compressionType, outputTimestampType.toString,
dateTimeRebaseMode, timestampRebaseMode, context, parquetFieldIdWriteEnabled,
- statsTrackers, debugOutputPath, holdGpuBetweenBatches, asyncOutputWriteEnabled)
+ statsTrackers, debugOutputPath, holdGpuBetweenBatches, asyncOutputWriteEnabled, fileIO)
}
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -313,9 +315,10 @@ class GpuParquetWriter(
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
debugDumpPath: Option[String],
holdGpuBetweenBatches: Boolean,
- useAsyncWrite: Boolean)
+ useAsyncWrite: Boolean,
+ fileIO: RapidsFileIO)
extends ColumnarOutputWriter(context, dataSchema, "Parquet", true, statsTrackers,
- debugDumpPath, holdGpuBetweenBatches, useAsyncWrite) {
+ debugDumpPath, holdGpuBetweenBatches, useAsyncWrite, fileIO) {
override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
val cols = GpuColumnVector.extractBases(batch)
cols.foreach { col =>
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala
index ea0e2869a03..47a330bc365 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveFileFormat.scala
@@ -24,6 +24,7 @@ import com.google.common.base.Charsets
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.jni.CastStrings
+import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO
import com.nvidia.spark.rapids.shims.BucketingUtilsShim
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -217,9 +218,10 @@ class GpuHiveParquetFileFormat(compType: CompressionType) extends ColumnarFileFo
dataSchema: StructType,
context: TaskAttemptContext,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String]): ColumnarOutputWriter = {
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO): ColumnarOutputWriter = {
new GpuHiveParquetWriter(path, dataSchema, context, compressionType, statsTrackers,
- debugOutputPath)
+ debugOutputPath, fileIO)
}
}
}
@@ -228,9 +230,10 @@ class GpuHiveParquetFileFormat(compType: CompressionType) extends ColumnarFileFo
class GpuHiveParquetWriter(override val path: String, dataSchema: StructType,
context: TaskAttemptContext, compType: CompressionType,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String])
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO)
extends ColumnarOutputWriter(context, dataSchema, "HiveParquet", true, statsTrackers,
- debugOutputPath) {
+ debugOutputPath, false, false, fileIO) {
override protected val tableWriter: CudfTableWriter = {
val optionsBuilder = SchemaUtils
@@ -260,8 +263,9 @@ class GpuHiveTextFileFormat extends ColumnarFileFormat with Logging with Seriali
dataSchema: StructType,
context: TaskAttemptContext,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String]): ColumnarOutputWriter = {
- new GpuHiveTextWriter(path, dataSchema, context, statsTrackers, debugOutputPath)
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO): ColumnarOutputWriter = {
+ new GpuHiveTextWriter(path, dataSchema, context, statsTrackers, debugOutputPath, fileIO)
}
}
}
@@ -271,9 +275,10 @@ class GpuHiveTextWriter(override val path: String,
dataSchema: StructType,
context: TaskAttemptContext,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String])
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO)
extends ColumnarOutputWriter(context, dataSchema, "HiveText", false, statsTrackers,
- debugOutputPath) {
+ debugOutputPath, false, false, fileIO) {
/**
* This reformats columns, to iron out inconsistencies between
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
index 8911a57e747..9ecdf3d1d1d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
@@ -26,6 +26,7 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
+import com.nvidia.spark.rapids.fileio.hadoop.HadoopFileIO
import com.nvidia.spark.rapids.shims.GpuFileFormatDataWriterShim
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
@@ -279,7 +280,8 @@ class GpuSingleDirectoryDataWriter(
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext,
statsTrackers = statsTrackers,
- debugOutputPath = debugOutputPath)
+ debugOutputPath = debugOutputPath,
+ fileIO = description.fileIO)
statsTrackers.foreach(_.newFile(currentPath))
}
@@ -623,7 +625,8 @@ class GpuDynamicPartitionDataSingleWriter(
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext,
statsTrackers = statsTrackers,
- debugOutputPath = debugOutputPath)
+ debugOutputPath = debugOutputPath,
+ fileIO = description.fileIO)
statsTrackers.foreach(_.newFile(currentPath))
outWriter
@@ -984,6 +987,8 @@ class GpuWriteJobDescription(
val concurrentWriterPartitionFlushSize: Long)
extends Serializable {
+ lazy val fileIO: HadoopFileIO = new HadoopFileIO(serializableHadoopConf.value)
+
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
s"""
|All columns: ${allColumns.mkString(", ")}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 7fa3cbdb9bb..69b14936eef 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -20,6 +20,7 @@ import java.time.ZoneId
import ai.rapids.cudf._
import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO
import com.nvidia.spark.rapids.shims.OrcShims
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
@@ -195,9 +196,10 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
dataSchema: StructType,
context: TaskAttemptContext,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker],
- debugOutputPath: Option[String]): ColumnarOutputWriter = {
+ debugOutputPath: Option[String],
+ fileIO: RapidsFileIO): ColumnarOutputWriter = {
new GpuOrcWriter(path, dataSchema, context, statsTrackers, debugOutputPath,
- holdGpuBetweenBatches, orcStripeSizeRows, asyncOutputWriteEnabled)
+ holdGpuBetweenBatches, orcStripeSizeRows, asyncOutputWriteEnabled, fileIO)
}
override def getFileExtension(context: TaskAttemptContext): String = {
@@ -224,9 +226,10 @@ class GpuOrcWriter(
debugOutputPath: Option[String],
holdGpuBetweenBatches: Boolean,
orcStripeSizeRows: Option[Integer],
- useAsyncWrite: Boolean)
+ useAsyncWrite: Boolean,
+ fileIO: RapidsFileIO)
extends ColumnarOutputWriter(context, dataSchema, "ORC", true, statsTrackers, debugOutputPath,
- holdGpuBetweenBatches, useAsyncWrite) {
+ holdGpuBetweenBatches, useAsyncWrite, fileIO) {
override val tableWriter: TableWriter = {
val builder = SchemaUtils
diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala
index a1b0f04b51a..bf3e2924320 100644
--- a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala
+++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala
@@ -61,7 +61,10 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach {
rangeName,
includeRetry,
mockJobDescription.statsTrackers.map(_.newTaskInstance()),
- None) {
+ None,
+ false,
+ false,
+ mockJobDescription.fileIO) {
// this writer (for tests) doesn't do anything and passes through the
// batch passed to it when asked to transform, which is done to
@@ -95,7 +98,7 @@ class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach {
types,
"",
includeRetry))
- when(mockOutputWriterFactory.newInstance(any(), any(), any(), any(), any()))
+ when(mockOutputWriterFactory.newInstance(any(), any(), any(), any(), any(), any()))
.thenAnswer(_ => mockOutputWriter)
}