Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO;
import com.nvidia.spark.rapids.jni.fileio.RapidsInputFile;
import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
import org.apache.iceberg.io.FileIO;

import java.io.IOException;
Expand Down Expand Up @@ -48,4 +49,9 @@ public IcebergFileIO(FileIO delegate) {
public IcebergInputFile newInputFile(String path) throws IOException {
return new IcebergInputFile(delegate.newInputFile(path));
}

@Override
public IcebergOutputFile newOutputFile(String path) throws IOException {
return new IcebergOutputFile(delegate.newOutputFile(path));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.fileio.iceberg;

import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
import org.apache.iceberg.io.OutputFile;

import java.io.IOException;
import java.util.Objects;

/**
* Implementation of {@link RapidsOutputFile} using Iceberg {@link OutputFile}.
*/
public class IcebergOutputFile implements RapidsOutputFile {
private final OutputFile delegate;

public IcebergOutputFile(OutputFile delegate) {
Objects.requireNonNull(delegate, "delegate can't be null");
this.delegate = delegate;
}

@Override
public IcebergOutputStream create(boolean overwrite) throws IOException {
if (overwrite) {
return new IcebergOutputStream(delegate.createOrOverwrite());
}
return new IcebergOutputStream(delegate.create());
}

@Override
public String getPath() {
return delegate.location();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.fileio.iceberg;

import com.nvidia.spark.rapids.jni.fileio.RapidsOutputStream;
import org.apache.iceberg.io.PositionOutputStream;

import java.io.IOException;
import java.io.OutputStream;

import static java.util.Objects.requireNonNull;

/**
* A {@link RapidsOutputStream} implementation that wraps an Iceberg {@link PositionOutputStream}.
*/
public class IcebergOutputStream extends RapidsOutputStream {
private final PositionOutputStream out;
private boolean closed;

public IcebergOutputStream(PositionOutputStream out) {
this.out = requireNonNull(out, "out can't be null");
this.closed = false;
}

@Override
public void write(int b) throws IOException {
out.write(b);
}

@Override
public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len);
}

@Override
public void flush() throws IOException {
out.flush();
}

@Override
public void close() throws IOException {
if (!closed) {
out.close();
closed = true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class GpuSparkFileWriterFactory(val table: Table,
val columnarOutputWriterFactory: ColumnarOutputWriterFactory,
val taskStatsTracker: ColumnarWriteTaskStatsTracker,
val hadoopConf: SerializableConfiguration,
val fileIO: IcebergFileIO
) extends FileWriterFactory[SpillableColumnarBatch] {
require(dataFileFormat == FileFormat.PARQUET,
s"GpuSparkFileWriterFactory only supports PARQUET file format, but got $dataFileFormat")
Expand Down Expand Up @@ -78,8 +79,8 @@ class GpuSparkFileWriterFactory(val table: Table,
dataSchema = dataSparkType,
context = taskAttemptContext,
statsTrackers = Seq(taskStatsTracker),
debugOutputPath = None
).asInstanceOf[GpuParquetWriter]
debugOutputPath = None,
fileIO).asInstanceOf[GpuParquetWriter]

new GpuIcebergParquetAppender(
gpuWriter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ import scala.util.{Failure, Success}
import com.nvidia.spark.rapids.{ColumnarOutputWriterFactory, GpuParquetFileFormat, GpuWrite, SparkPlanMeta, SpillableColumnarBatch}
import com.nvidia.spark.rapids.Arm.closeOnExcept
import com.nvidia.spark.rapids.SpillPriorities.ACTIVE_ON_DECK_PRIORITY
import com.nvidia.spark.rapids.fileio.iceberg.IcebergFileIO
import com.nvidia.spark.rapids.iceberg.GpuIcebergPartitioner
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.shaded.org.apache.commons.lang3.reflect.{FieldUtils, MethodUtils}
import org.apache.iceberg.{DataFile, FileFormat, PartitionSpec, Schema, SerializableTable, SnapshotUpdate, Table}
import org.apache.iceberg.io.{DataWriteResult, FileIO, GpuClusteredDataWriter, GpuFanoutDataWriter, GpuRollingDataWriter, OutputFileFactory, PartitioningWriter}
Expand Down Expand Up @@ -96,7 +95,6 @@ class GpuSparkWrite(cpu: SparkWrite) extends GpuWrite with RequiresDistributionA
val tmpJob = Job.getInstance(hadoopConf)
tmpJob.setOutputKeyClass(classOf[Void])
tmpJob.setOutputValueClass(classOf[InternalRow])
FileOutputFormat.setOutputPath(tmpJob, new Path(table.location()))
tmpJob
}

Expand Down Expand Up @@ -202,6 +200,8 @@ class GpuWriterFactory(val tableBroadcast: Broadcast[Table],
val hadoopConf: SerializableConfiguration
) extends DataWriterFactory {

private lazy val fileIO: IcebergFileIO = new IcebergFileIO(tableBroadcast.value.io())

override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
val table = tableBroadcast.value
val spec = table.specs().get(outputSpecId)
Expand All @@ -221,7 +221,8 @@ class GpuWriterFactory(val tableBroadcast: Broadcast[Table],
format,
outputWriterFactory,
statsTracker.newTaskInstance(),
hadoopConf)
hadoopConf,
fileIO)

if (spec.isUnpartitioned) {
new GpuUnpartitionedDataWriter(writerFactory, outputFileFactory, io, spec, targetFileSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@
from marks import iceberg, ignore_order, allow_non_gpu
from spark_session import with_gpu_session, with_cpu_session, is_spark_35x

pytestmark = [
pytest.mark.skipif(not is_spark_35x(),
reason="Current spark-rapids only support spark 3.5.x"),
pytest.mark.skipif(is_iceberg_remote_catalog(),
reason="https://github.com/NVIDIA/spark-rapids/issues/13471")
]
pytestmark = pytest.mark.skipif(not is_spark_35x(),
reason="Current spark-rapids only support spark 3.5.x")


def do_test_insert_into_table_sql(spark_tmp_table_factory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import com.nvidia.spark.rapids.jni.fileio.RapidsFileIO;
import com.nvidia.spark.rapids.jni.fileio.RapidsInputFile;
import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.util.SerializableConfiguration;

Expand All @@ -38,12 +40,18 @@ public HadoopFileIO(Configuration hadoopConf) {
}

@Override
public RapidsInputFile newInputFile(String path) throws IOException {
public HadoopInputFile newInputFile(String path) throws IOException {
return this.newInputFile(new Path(path));
}

@Override
public RapidsInputFile newInputFile(Path path) throws IOException {
public HadoopInputFile newInputFile(Path path) throws IOException {
return HadoopInputFile.create(path, hadoopConf.value());
}

@Override
public HadoopOutputFile newOutputFile(String path) throws IOException {
Objects.requireNonNull(path, "path can't be null");
return HadoopOutputFile.create(new Path(path), hadoopConf.value());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.fileio.hadoop;

import com.nvidia.spark.rapids.jni.fileio.RapidsOutputFile;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
* Implementation of {@link RapidsOutputFile} using the Hadoop file system.
* <br/>
* This class provides methods to create an output file and to obtain the absolute path.
*/
public class HadoopOutputFile implements RapidsOutputFile {
private final Path filePath;
private final FileSystem fs;

public static HadoopOutputFile create(Path filePath, Configuration conf)
throws IOException {
Objects.requireNonNull(filePath, "filePath can't be null");
Objects.requireNonNull(conf, "Hadoop conf can't be null");
FileSystem fs = filePath.getFileSystem(conf);
return new HadoopOutputFile(filePath, fs);
}

private HadoopOutputFile(Path filePath, FileSystem fs) {
Objects.requireNonNull(filePath, "filePath can't be null");
Objects.requireNonNull(fs, "FileSystem can't be null");
this.filePath = filePath;
this.fs = fs;
}

@Override
public HadoopOutputStream create(boolean overwrite) throws IOException {
FSDataOutputStream output = fs.create(filePath, overwrite);
return new HadoopOutputStream(output);
}

@Override
public String getPath() {
return filePath.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.fileio.hadoop;

import com.nvidia.spark.rapids.jni.fileio.RapidsOutputStream;
import org.apache.hadoop.fs.FSDataOutputStream;

import java.io.IOException;
import java.io.OutputStream;

import static java.util.Objects.requireNonNull;

/**
* A {@link RapidsOutputStream} implementation that wraps a Hadoop {@link FSDataOutputStream}.
* <br/>
* This class delegates to the underlying output stream for write and close operations.
*/
public class HadoopOutputStream extends RapidsOutputStream {
private final FSDataOutputStream out;
private boolean closed;

public HadoopOutputStream(FSDataOutputStream out) {
this.out = requireNonNull(out, "out can't be null");
this.closed = false;
}

@Override
public void write(int b) throws IOException {
out.write(b);
}

@Override
public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len);
}

@Override
public void flush() throws IOException {
out.flush();
}

@Override
public void close() throws IOException {
if (!closed) {
out.close();
closed = true;
}
}
}
Loading