diff --git a/cpp/include/cudf/io/csv.hpp b/cpp/include/cudf/io/csv.hpp
index 1fc4114b94c..92b5447527c 100644
--- a/cpp/include/cudf/io/csv.hpp
+++ b/cpp/include/cudf/io/csv.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1332,7 +1332,7 @@ class csv_writer_options {
size_type _rows_per_chunk = std::numeric_limits::max();
// character to use for separating lines (default "\n")
std::string _line_terminator = "\n";
- // character to use for separating lines (default "\n")
+ // character to use for separating column values (default ",")
char _inter_column_delimiter = ',';
// string to use for values != 0 in INT8 types (default 'true')
std::string _true_value = std::string{"true"};
@@ -1422,9 +1422,9 @@ class csv_writer_options {
[[nodiscard]] std::string get_line_terminator() const { return _line_terminator; }
/**
- * @brief Returns character used for separating lines.
+ * @brief Returns character used for separating column values.
*
- * @return Character used for separating lines
+ * @return Character used for separating column values.
*/
[[nodiscard]] char get_inter_column_delimiter() const { return _inter_column_delimiter; }
@@ -1479,9 +1479,9 @@ class csv_writer_options {
void set_line_terminator(std::string term) { _line_terminator = term; }
/**
- * @brief Sets character used for separating lines.
+ * @brief Sets character used for separating column values.
*
- * @param delim Character to indicate delimiting
+ * @param delim Character to delimit column values
*/
void set_inter_column_delimiter(char delim) { _inter_column_delimiter = delim; }
@@ -1498,6 +1498,13 @@ class csv_writer_options {
* @param val String to represent values == 0 in INT8 types
*/
void set_false_value(std::string val) { _false_value = val; }
+
+ /**
+ * @brief (Re)sets the table being written.
+ *
+ * Only the stored table view is replaced; every other option keeps its
+ * current value.
+ *
+ * @param table Table to be written
+ */
+ void set_table(table_view const& table) { _table = table; }
};
/**
@@ -1586,9 +1593,9 @@ class csv_writer_options_builder {
}
/**
- * @brief Sets character used for separating lines.
+ * @brief Sets character used for separating column values.
*
- * @param delim Character to indicate delimiting
+ * @param delim Character to delimit column values
* @return this for chaining
*/
csv_writer_options_builder& inter_column_delimiter(char delim)
diff --git a/java/src/main/java/ai/rapids/cudf/CSVWriterOptions.java b/java/src/main/java/ai/rapids/cudf/CSVWriterOptions.java
new file mode 100644
index 00000000000..410eeab2b18
--- /dev/null
+++ b/java/src/main/java/ai/rapids/cudf/CSVWriterOptions.java
@@ -0,0 +1,134 @@
+/*
+ *
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package ai.rapids.cudf;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Options that control how a table is written out as CSV.
+ *
+ * The constructor is private: instances are obtained by configuring a
+ * {@link Builder} (see {@link #builder()}) and calling {@link Builder#build()}.
+ */
+public class CSVWriterOptions {
+
+ // Column names, copied from the builder's list into an array at build time.
+ private String[] columnNames;
+ // Whether a header row is requested (default: false).
+ private Boolean includeHeader = false;
+ // String separating rows (default: "\n").
+ private String rowDelimiter = "\n";
+ // Single byte separating the fields of a row (default: ',').
+ private byte fieldDelimiter = ',';
+ // String used to represent null values (default: empty string).
+ private String nullValue = "";
+ // String used to represent false values (default: "false").
+ private String falseValue = "false";
+ // String used to represent true values (default: "true").
+ private String trueValue = "true";
+
+ /**
+ * Copies every setting out of the builder. The column-name list is converted to an array
+ * here, so changes made to the list after build() do not affect this instance.
+ */
+ private CSVWriterOptions(Builder builder) {
+ this.columnNames = builder.columnNames.toArray(new String[builder.columnNames.size()]);
+ this.nullValue = builder.nullValue;
+ this.includeHeader = builder.includeHeader;
+ this.fieldDelimiter = builder.fieldDelimiter;
+ this.rowDelimiter = builder.rowDelimiter;
+ this.falseValue = builder.falseValue;
+ this.trueValue = builder.trueValue;
+ }
+
+ /**
+ * @return the column names; this is the internal array itself, not a copy
+ */
+ public String[] getColumnNames() {
+ return columnNames;
+ }
+
+ /**
+ * @return whether a header row is requested
+ */
+ public Boolean getIncludeHeader() {
+ return includeHeader;
+ }
+
+ /**
+ * @return the string separating rows
+ */
+ public String getRowDelimiter() {
+ return rowDelimiter;
+ }
+
+ /**
+ * @return the byte separating the fields of a row
+ */
+ public byte getFieldDelimiter() {
+ return fieldDelimiter;
+ }
+
+ /**
+ * @return the string used to represent null values
+ */
+ public String getNullValue() {
+ return nullValue;
+ }
+
+ /**
+ * @return the string used to represent true values
+ */
+ public String getTrueValue() {
+ return trueValue;
+ }
+
+ /**
+ * @return the string used to represent false values
+ */
+ public String getFalseValue() {
+ return falseValue;
+ }
+
+ /**
+ * @return a new Builder holding the default settings
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Fluent builder for {@link CSVWriterOptions}. Each with* method stores its argument and
+ * returns this builder so calls can be chained.
+ */
+ public static class Builder {
+
+ // Defaults mirror the field defaults of CSVWriterOptions; column names start out empty.
+ private List columnNames = Collections.emptyList();
+ private Boolean includeHeader = false;
+ private String rowDelimiter = "\n";
+ private byte fieldDelimiter = ',';
+ private String nullValue = "";
+ private String falseValue = "false";
+ private String trueValue = "true";
+
+ /**
+ * @return a CSVWriterOptions holding the settings currently stored in this builder
+ */
+ public CSVWriterOptions build() {
+ return new CSVWriterOptions(this);
+ }
+
+ /**
+ * Sets the column names. The list reference is stored as-is (no copy is made here);
+ * its contents are read when build() is called.
+ * @param columnNames names of the columns
+ * @return this builder
+ */
+ public Builder withColumnNames(List columnNames) {
+ this.columnNames = columnNames;
+ return this;
+ }
+
+ /**
+ * Varargs convenience overload: collects the names into a new list and delegates to
+ * the List overload.
+ * @param columnNames names of the columns
+ * @return this builder
+ */
+ public Builder withColumnNames(String... columnNames) {
+ List columnNamesList = new ArrayList<>();
+ for (String columnName : columnNames) {
+ columnNamesList.add(columnName);
+ }
+ return withColumnNames(columnNamesList);
+ }
+
+ /**
+ * @param includeHeader whether a header row is requested
+ * @return this builder
+ */
+ public Builder withIncludeHeader(Boolean includeHeader) {
+ this.includeHeader = includeHeader;
+ return this;
+ }
+
+ /**
+ * @param rowDelimiter string separating rows
+ * @return this builder
+ */
+ public Builder withRowDelimiter(String rowDelimiter) {
+ this.rowDelimiter = rowDelimiter;
+ return this;
+ }
+
+ /**
+ * @param fieldDelimiter byte separating the fields of a row
+ * @return this builder
+ */
+ public Builder withFieldDelimiter(byte fieldDelimiter) {
+ this.fieldDelimiter = fieldDelimiter;
+ return this;
+ }
+
+ /**
+ * @param nullValue string used to represent null values
+ * @return this builder
+ */
+ public Builder withNullValue(String nullValue) {
+ this.nullValue = nullValue;
+ return this;
+ }
+
+ /**
+ * @param trueValue string used to represent true values
+ * @return this builder
+ */
+ public Builder withTrueValue(String trueValue) {
+ this.trueValue = trueValue;
+ return this;
+ }
+
+ /**
+ * @param falseValue string used to represent false values
+ * @return this builder
+ */
+ public Builder withFalseValue(String falseValue) {
+ this.falseValue = falseValue;
+ return this;
+ }
+ }
+}
diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java
index b93352fa9ac..36dec194017 100644
--- a/java/src/main/java/ai/rapids/cudf/Table.java
+++ b/java/src/main/java/ai/rapids/cudf/Table.java
@@ -1,6 +1,6 @@
/*
*
- * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -857,6 +857,82 @@ public static Table readCSV(Schema schema, CSVOptions opts, HostMemoryBuffer buf
opts.getFalseValues()));
}
+ /**
+ * Native entry point that writes the table behind the {@code table} handle to the CSV file
+ * at {@code outputPath}. The arguments in between are the individual settings carried by a
+ * {@link CSVWriterOptions}.
+ */
+ private static native void writeCSVToFile(long table,
+ String[] columnNames,
+ boolean includeHeader,
+ String rowDelimiter,
+ byte fieldDelimiter,
+ String nullValue,
+ String trueValue,
+ String falseValue,
+ String outputPath) throws CudfException;
+
+ /**
+ * Writes this table to a CSV file.
+ * @param options CSV settings (column names, header flag, delimiters, null/true/false strings)
+ * @param outputPath path of the file to write
+ */
+ public void writeCSVToFile(CSVWriterOptions options, String outputPath) {
+ // Unpack the options object into the flat argument list expected by the native method.
+ writeCSVToFile(nativeHandle,
+ options.getColumnNames(),
+ options.getIncludeHeader(),
+ options.getRowDelimiter(),
+ options.getFieldDelimiter(),
+ options.getNullValue(),
+ options.getTrueValue(),
+ options.getFalseValue(),
+ outputPath);
+ }
+
+ /**
+ * Native entry point that creates a chunked CSV writer delivering its output to
+ * {@code buffer}. Returns an opaque native writer handle that must eventually be passed to
+ * {@link #endWriteCSVToBuffer(long)}.
+ */
+ private static native long startWriteCSVToBuffer(String[] columnNames,
+ boolean includeHeader,
+ String rowDelimiter,
+ byte fieldDelimiter,
+ String nullValue,
+ String trueValue,
+ String falseValue,
+ HostBufferConsumer buffer) throws CudfException;
+
+ // Writes the table behind tableHandle through the writer behind writerHandle.
+ private static native void writeCSVChunkToBuffer(long writerHandle, long tableHandle);
+
+ // Closes the native writer behind writerHandle; the handle must not be used afterwards.
+ private static native void endWriteCSVToBuffer(long writerHandle);
+
+ /**
+ * TableWriter that writes tables as CSV into a HostBufferConsumer through a native
+ * chunked writer. A writerHandle of 0 marks the writer as closed.
+ */
+ private static class CSVTableWriter implements TableWriter {
+ private long writerHandle;
+ private HostBufferConsumer consumer;
+
+ /**
+ * Creates the native writer from the given options and remembers the consumer so that
+ * close() can notify it.
+ */
+ private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) {
+ this.writerHandle = startWriteCSVToBuffer(options.getColumnNames(),
+ options.getIncludeHeader(),
+ options.getRowDelimiter(),
+ options.getFieldDelimiter(),
+ options.getNullValue(),
+ options.getTrueValue(),
+ options.getFalseValue(),
+ consumer);
+ this.consumer = consumer;
+ }
+
+ /**
+ * Writes one table through the native writer.
+ * @throws IllegalStateException if this writer has already been closed
+ */
+ @Override
+ public void write(Table table) {
+ if (writerHandle == 0) {
+ throw new IllegalStateException("Writer was already closed");
+ }
+ writeCSVChunkToBuffer(writerHandle, table.nativeHandle);
+ }
+
+ /**
+ * Ends the native write and then tells the consumer that no more data will arrive.
+ * Both the handle and the consumer reference are cleared, so a second call is a no-op.
+ */
+ @Override
+ public void close() throws CudfException {
+ if (writerHandle != 0) {
+ endWriteCSVToBuffer(writerHandle);
+ writerHandle = 0;
+ }
+ if (consumer != null) {
+ consumer.done();
+ consumer = null;
+ }
+ }
+ }
+
+ /**
+ * Get a TableWriter that writes tables as CSV into the given buffer consumer.
+ * @param options CSV settings used for everything written through the returned writer
+ * @param bufferConsumer receiver of the host buffers holding the CSV output
+ * @return the writer; the caller is responsible for closing it
+ */
+ public static TableWriter getCSVBufferWriter(CSVWriterOptions options, HostBufferConsumer bufferConsumer) {
+ return new CSVTableWriter(options, bufferConsumer);
+ }
+
/**
* Read a JSON file using the default JSONOptions.
* @param schema the schema of the file. You may use Schema.INFERRED to infer the schema.
diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp
index f417cdc597d..8740669db1f 100644
--- a/java/src/main/native/src/TableJni.cpp
+++ b/java/src/main/native/src/TableJni.cpp
@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
#include
#include
@@ -47,165 +46,17 @@
#include
#include
+#include "csv_chunked_writer.hpp"
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
#include "jni_compiled_expr.hpp"
#include "jni_utils.hpp"
+#include "jni_writer_data_sink.hpp"
#include "row_conversion.hpp"
namespace cudf {
namespace jni {
-constexpr long MINIMUM_WRITE_BUFFER_SIZE = 10 * 1024 * 1024; // 10 MB
-
-class jni_writer_data_sink final : public cudf::io::data_sink {
-public:
- explicit jni_writer_data_sink(JNIEnv *env, jobject callback) {
- if (env->GetJavaVM(&jvm) < 0) {
- throw std::runtime_error("GetJavaVM failed");
- }
-
- jclass cls = env->GetObjectClass(callback);
- if (cls == nullptr) {
- throw cudf::jni::jni_exception("class not found");
- }
-
- handle_buffer_method =
- env->GetMethodID(cls, "handleBuffer", "(Lai/rapids/cudf/HostMemoryBuffer;J)V");
- if (handle_buffer_method == nullptr) {
- throw cudf::jni::jni_exception("handleBuffer method");
- }
-
- this->callback = env->NewGlobalRef(callback);
- if (this->callback == nullptr) {
- throw cudf::jni::jni_exception("global ref");
- }
- }
-
- virtual ~jni_writer_data_sink() {
- // This should normally be called by a JVM thread. If the JVM environment is missing then this
- // is likely being triggered by the C++ runtime during shutdown. In that case the JVM may
- // already be destroyed and this thread should not try to attach to get an environment.
- JNIEnv *env = nullptr;
- if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) {
- env->DeleteGlobalRef(callback);
- if (current_buffer != nullptr) {
- env->DeleteGlobalRef(current_buffer);
- }
- }
- callback = nullptr;
- current_buffer = nullptr;
- }
-
- void host_write(void const *data, size_t size) override {
- JNIEnv *env = cudf::jni::get_jni_env(jvm);
- long left_to_copy = static_cast(size);
- const char *copy_from = static_cast(data);
- while (left_to_copy > 0) {
- long buffer_amount_available = current_buffer_len - current_buffer_written;
- if (buffer_amount_available <= 0) {
- // should never be < 0, but just to be safe
- rotate_buffer(env);
- buffer_amount_available = current_buffer_len - current_buffer_written;
- }
- long amount_to_copy =
- left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available;
- char *copy_to = current_buffer_data + current_buffer_written;
-
- std::memcpy(copy_to, copy_from, amount_to_copy);
- copy_from = copy_from + amount_to_copy;
- current_buffer_written += amount_to_copy;
- total_written += amount_to_copy;
- left_to_copy -= amount_to_copy;
- }
- }
-
- bool supports_device_write() const override { return true; }
-
- void device_write(void const *gpu_data, size_t size, rmm::cuda_stream_view stream) override {
- JNIEnv *env = cudf::jni::get_jni_env(jvm);
- long left_to_copy = static_cast(size);
- const char *copy_from = static_cast(gpu_data);
- while (left_to_copy > 0) {
- long buffer_amount_available = current_buffer_len - current_buffer_written;
- if (buffer_amount_available <= 0) {
- // should never be < 0, but just to be safe
- stream.synchronize();
- rotate_buffer(env);
- buffer_amount_available = current_buffer_len - current_buffer_written;
- }
- long amount_to_copy =
- left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available;
- char *copy_to = current_buffer_data + current_buffer_written;
-
- CUDF_CUDA_TRY(
- cudaMemcpyAsync(copy_to, copy_from, amount_to_copy, cudaMemcpyDefault, stream.value()));
-
- copy_from = copy_from + amount_to_copy;
- current_buffer_written += amount_to_copy;
- total_written += amount_to_copy;
- left_to_copy -= amount_to_copy;
- }
- stream.synchronize();
- }
-
- std::future device_write_async(void const *gpu_data, size_t size,
- rmm::cuda_stream_view stream) override {
- // Call the sync version until figuring out how to write asynchronously.
- device_write(gpu_data, size, stream);
- return std::async(std::launch::deferred, [] {});
- }
-
- void flush() override {
- if (current_buffer_written > 0) {
- JNIEnv *env = cudf::jni::get_jni_env(jvm);
- handle_buffer(env, current_buffer, current_buffer_written);
- if (current_buffer != nullptr) {
- env->DeleteGlobalRef(current_buffer);
- }
- current_buffer = nullptr;
- current_buffer_len = 0;
- current_buffer_data = nullptr;
- current_buffer_written = 0;
- }
- }
-
- size_t bytes_written() override { return total_written; }
-
- void set_alloc_size(long size) { this->alloc_size = size; }
-
-private:
- void rotate_buffer(JNIEnv *env) {
- if (current_buffer != nullptr) {
- handle_buffer(env, current_buffer, current_buffer_written);
- env->DeleteGlobalRef(current_buffer);
- current_buffer = nullptr;
- }
- jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true);
- current_buffer = env->NewGlobalRef(tmp_buffer);
- current_buffer_len = get_host_buffer_length(env, current_buffer);
- current_buffer_data = reinterpret_cast(get_host_buffer_address(env, current_buffer));
- current_buffer_written = 0;
- }
-
- void handle_buffer(JNIEnv *env, jobject buffer, jlong len) {
- env->CallVoidMethod(callback, handle_buffer_method, buffer, len);
- if (env->ExceptionCheck()) {
- throw std::runtime_error("handleBuffer threw an exception");
- }
- }
-
- JavaVM *jvm;
- jobject callback;
- jmethodID handle_buffer_method;
- jobject current_buffer = nullptr;
- char *current_buffer_data = nullptr;
- long current_buffer_len = 0;
- long current_buffer_written = 0;
- size_t total_written = 0;
- long alloc_size = MINIMUM_WRITE_BUFFER_SIZE;
-};
-
template class jni_table_writer_handle final {
public:
explicit jni_table_writer_handle(std::unique_ptr writer)
@@ -1349,6 +1200,118 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV(
CATCH_STD(env, NULL);
}
+// JNI implementation of Table.writeCSVToFile: writes the table behind j_table_handle to the
+// file at j_output_path in a single cudf::io::write_csv call, using the given CSV settings.
+// Errors raised inside the try block are handed to CATCH_STD.
+JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeCSVToFile(
+ JNIEnv *env, jclass, jlong j_table_handle, jobjectArray j_column_names, jboolean include_header,
+ jstring j_row_delimiter, jbyte j_field_delimiter, jstring j_null_value, jstring j_true_value,
+ jstring j_false_value, jstring j_output_path) {
+ JNI_NULL_CHECK(env, j_table_handle, "table handle cannot be null.", );
+ JNI_NULL_CHECK(env, j_column_names, "column name array cannot be null", );
+ JNI_NULL_CHECK(env, j_row_delimiter, "row delimiter cannot be null", );
+ // NOTE(review): j_field_delimiter is a jbyte value, not a reference; confirm that rejecting
+ // a 0 delimiter through the null check is intended.
+ JNI_NULL_CHECK(env, j_field_delimiter, "field delimiter cannot be null", );
+ JNI_NULL_CHECK(env, j_null_value, "null representation string cannot be itself null", );
+ JNI_NULL_CHECK(env, j_true_value, "representation string for `true` cannot be null", );
+ JNI_NULL_CHECK(env, j_false_value, "representation string for `false` cannot be null", );
+ JNI_NULL_CHECK(env, j_output_path, "output path cannot be null", );
+
+ try {
+ cudf::jni::auto_set_device(env);
+
+ auto const native_output_path = cudf::jni::native_jstring{env, j_output_path};
+ auto const output_path = native_output_path.get();
+
+ auto const table = reinterpret_cast(j_table_handle);
+ auto const n_column_names = cudf::jni::native_jstringArray{env, j_column_names};
+ auto const column_names = n_column_names.as_cpp_vector();
+
+ // The native_jstring wrappers are kept in named locals so they outlive the builder calls
+ // below that read them through get().
+ auto const line_terminator = cudf::jni::native_jstring{env, j_row_delimiter};
+ auto const na_rep = cudf::jni::native_jstring{env, j_null_value};
+ auto const true_value = cudf::jni::native_jstring{env, j_true_value};
+ auto const false_value = cudf::jni::native_jstring{env, j_false_value};
+
+ auto options = cudf::io::csv_writer_options::builder(cudf::io::sink_info{output_path}, *table)
+ .names(column_names)
+ .include_header(static_cast(include_header))
+ .line_terminator(line_terminator.get())
+ .inter_column_delimiter(j_field_delimiter)
+ .na_rep(na_rep.get())
+ .true_value(true_value.get())
+ .false_value(false_value.get());
+
+ cudf::io::write_csv(options.build());
+ }
+ CATCH_STD(env, );
+}
+
+// JNI implementation of Table.startWriteCSVToBuffer: creates a csv_chunked_writer whose output
+// is delivered to the Java HostBufferConsumer `j_buffer`.
+//
+// Returns the new writer as a jlong handle, or 0 if an argument check fails or an exception is
+// raised. The handle owns the data sink and must be released through endWriteCSVToBuffer.
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_startWriteCSVToBuffer(
+    JNIEnv *env, jclass, jobjectArray j_column_names, jboolean include_header,
+    jstring j_row_delimiter, jbyte j_field_delimiter, jstring j_null_value, jstring j_true_value,
+    jstring j_false_value, jobject j_buffer) {
+  JNI_NULL_CHECK(env, j_column_names, "column name array cannot be null", 0);
+  JNI_NULL_CHECK(env, j_row_delimiter, "row delimiter cannot be null", 0);
+  JNI_NULL_CHECK(env, j_field_delimiter, "field delimiter cannot be null", 0);
+  JNI_NULL_CHECK(env, j_null_value, "null representation string cannot be itself null", 0);
+  // Both strings are wrapped and read via get() below, exactly as in writeCSVToFile, so they
+  // get the same up-front null checks here.
+  JNI_NULL_CHECK(env, j_true_value, "representation string for `true` cannot be null", 0);
+  JNI_NULL_CHECK(env, j_false_value, "representation string for `false` cannot be null", 0);
+  JNI_NULL_CHECK(env, j_buffer, "output buffer cannot be null", 0);
+
+  try {
+    cudf::jni::auto_set_device(env);
+
+    auto data_sink = std::make_unique<cudf::jni::jni_writer_data_sink>(env, j_buffer);
+
+    auto const n_column_names = cudf::jni::native_jstringArray{env, j_column_names};
+    auto const column_names = n_column_names.as_cpp_vector();
+
+    // Keep the wrappers alive in named locals while the builder reads them through get().
+    auto const line_terminator = cudf::jni::native_jstring{env, j_row_delimiter};
+    auto const na_rep = cudf::jni::native_jstring{env, j_null_value};
+    auto const true_value = cudf::jni::native_jstring{env, j_true_value};
+    auto const false_value = cudf::jni::native_jstring{env, j_false_value};
+
+    // No table exists yet, so the options start with an empty table_view; the chunked writer
+    // installs the real table on every write.
+    auto options = cudf::io::csv_writer_options::builder(cudf::io::sink_info{data_sink.get()},
+                                                         cudf::table_view{})
+                       .names(column_names)
+                       .include_header(static_cast<bool>(include_header))
+                       .line_terminator(line_terminator.get())
+                       .inter_column_delimiter(j_field_delimiter)
+                       .na_rep(na_rep.get())
+                       .true_value(true_value.get())
+                       .false_value(false_value.get())
+                       .build();
+
+    // The writer takes ownership of data_sink (it is moved out of the local unique_ptr).
+    return ptr_as_jlong(new cudf::jni::io::csv_chunked_writer{options, data_sink});
+  }
+  CATCH_STD(env, 0);
+}
+
+// JNI implementation of Table.writeCSVChunkToBuffer: writes the table behind j_table_handle
+// through the chunked CSV writer behind j_writer_handle. Neither handle is released here.
+JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeCSVChunkToBuffer(JNIEnv *env, jclass,
+ jlong j_writer_handle,
+ jlong j_table_handle) {
+ JNI_NULL_CHECK(env, j_writer_handle, "writer handle cannot be null.", );
+ JNI_NULL_CHECK(env, j_table_handle, "table handle cannot be null.", );
+
+ // Both handles are raw native pointers that were passed to Java as jlong values.
+ auto const table = reinterpret_cast(j_table_handle);
+ auto writer = reinterpret_cast(j_writer_handle);
+
+ try {
+ cudf::jni::auto_set_device(env);
+ writer->write(*table);
+ }
+ CATCH_STD(env, );
+}
+
+// JNI implementation of Table.endWriteCSVToBuffer: closes the chunked CSV writer behind
+// j_writer_handle and destroys it. The handle is invalid once this function returns.
+JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_endWriteCSVToBuffer(JNIEnv *env, jclass,
+ jlong j_writer_handle) {
+ JNI_NULL_CHECK(env, j_writer_handle, "writer handle cannot be null.", );
+
+ using cudf::jni::io::csv_chunked_writer;
+ // Take ownership immediately so the writer is deleted when this function exits, including
+ // when close() throws.
+ auto writer =
+ std::unique_ptr{reinterpret_cast(j_writer_handle)};
+
+ try {
+ cudf::jni::auto_set_device(env);
+ writer->close();
+ }
+ CATCH_STD(env, );
+}
+
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readAndInferJSON(
JNIEnv *env, jclass, jlong buffer, jlong buffer_length, jboolean day_first, jboolean lines) {
diff --git a/java/src/main/native/src/csv_chunked_writer.hpp b/java/src/main/native/src/csv_chunked_writer.hpp
new file mode 100644
index 00000000000..1f1e73a1a4b
--- /dev/null
+++ b/java/src/main/native/src/csv_chunked_writer.hpp
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+
+#include
+
+#include "jni_writer_data_sink.hpp"
+
+namespace cudf::jni::io {
+
+/**
+ * @brief Class to write multiple Tables into the jni_writer_data_sink.
+ *
+ * Each call to write() issues a separate cudf::io::write_csv using a stored copy of the
+ * writer options; the header setting is switched off after the first write. The writer owns
+ * the sink it writes to.
+ */
+class csv_chunked_writer {
+
+ cudf::io::csv_writer_options _options;
+ std::unique_ptr _sink;
+
+ bool _first_write_completed = false; ///< Decides if header should be written.
+
+public:
+ /**
+ * @brief Constructs the writer from options that already refer to `sink`.
+ *
+ * @param options Writer options; must name exactly one user sink, which must be `sink`
+ * @param sink Sink to write to. Ownership is moved into the writer, leaving the caller's
+ * pointer moved-from.
+ */
+ explicit csv_chunked_writer(cudf::io::csv_writer_options options,
+ std::unique_ptr &sink)
+ : _options{options}, _sink{std::move(sink)} {
+ auto const &sink_info = _options.get_sink();
+ // Assert invariants.
+ CUDF_EXPECTS(sink_info.type() != cudf::io::io_type::FILEPATH,
+ "Currently, chunked CSV writes to files is not supported.");
+
+ // Note: csv_writer_options ties the sink(s) to the options, and exposes
+ // no way to modify the sinks afterwards.
+ // Ideally, the options would have been separate from the tables written,
+ // and the destination sinks.
+ // Here, we retain a modifiable reference to the sink, and confirm the
+ // options point to the same sink.
+ CUDF_EXPECTS(sink_info.num_sinks() == 1, "csv_chunked_writer should have exactly one sink.");
+ CUDF_EXPECTS(sink_info.user_sinks()[0] == _sink.get(), "Sink mismatch.");
+ }
+
+ /**
+ * @brief Writes one table to the sink as CSV.
+ *
+ * @param table Table to write
+ */
+ void write(cudf::table_view const &table) {
+ if (_first_write_completed) {
+ _options.enable_include_header(false); // Don't write header after the first write.
+ }
+
+ _options.set_table(table);
+ // Presumably sized so that the whole table goes out as a single chunk -- TODO confirm.
+ _options.set_rows_per_chunk(table.num_rows());
+
+ cudf::io::write_csv(_options);
+ _first_write_completed = true;
+ }
+
+ /**
+ * @brief Flushes the sink. The sink itself is released when the writer is destroyed.
+ */
+ void close() {
+ // Flush pending writes to sink.
+ _sink->flush();
+ }
+};
+
+} // namespace cudf::jni::io
diff --git a/java/src/main/native/src/jni_writer_data_sink.hpp b/java/src/main/native/src/jni_writer_data_sink.hpp
new file mode 100644
index 00000000000..05fe594fcd5
--- /dev/null
+++ b/java/src/main/native/src/jni_writer_data_sink.hpp
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include
+
+#include "cudf_jni_apis.hpp"
+#include "jni_utils.hpp"
+
+namespace cudf::jni {
+
+// Default allocation size requested for each host buffer handed to the Java callback.
+constexpr long MINIMUM_WRITE_BUFFER_SIZE = 10 * 1024 * 1024; // 10 MB
+
+/**
+ * @brief data_sink that forwards written bytes to a Java callback object.
+ *
+ * Bytes are accumulated in a host buffer obtained from allocate_host_buffer. When the buffer
+ * is full, or on flush(), it is passed to the callback's
+ * `handleBuffer(HostMemoryBuffer, long)` method together with the number of bytes written
+ * into it.
+ */
+class jni_writer_data_sink final : public cudf::io::data_sink {
+public:
+ /**
+ * @brief Looks up `handleBuffer` on the callback's class and takes a global reference to
+ * the callback.
+ *
+ * @param env JNI environment of the calling thread
+ * @param callback Java object that receives the filled buffers
+ */
+ explicit jni_writer_data_sink(JNIEnv *env, jobject callback) {
+ if (env->GetJavaVM(&jvm) < 0) {
+ throw std::runtime_error("GetJavaVM failed");
+ }
+
+ jclass cls = env->GetObjectClass(callback);
+ if (cls == nullptr) {
+ throw cudf::jni::jni_exception("class not found");
+ }
+
+ handle_buffer_method =
+ env->GetMethodID(cls, "handleBuffer", "(Lai/rapids/cudf/HostMemoryBuffer;J)V");
+ if (handle_buffer_method == nullptr) {
+ throw cudf::jni::jni_exception("handleBuffer method");
+ }
+
+ // A global reference keeps the callback usable after this JNI call returns.
+ this->callback = env->NewGlobalRef(callback);
+ if (this->callback == nullptr) {
+ throw cudf::jni::jni_exception("global ref");
+ }
+ }
+
+ /**
+ * @brief Releases the global references, but only when a JNI environment is available on
+ * the current thread. A buffer that was never flushed is not delivered to the callback.
+ */
+ virtual ~jni_writer_data_sink() {
+ // This should normally be called by a JVM thread. If the JVM environment is missing then this
+ // is likely being triggered by the C++ runtime during shutdown. In that case the JVM may
+ // already be destroyed and this thread should not try to attach to get an environment.
+ JNIEnv *env = nullptr;
+ if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) {
+ env->DeleteGlobalRef(callback);
+ if (current_buffer != nullptr) {
+ env->DeleteGlobalRef(current_buffer);
+ }
+ }
+ callback = nullptr;
+ current_buffer = nullptr;
+ }
+
+ /**
+ * @brief Copies `size` bytes of host memory into the current buffer, rotating to a fresh
+ * buffer whenever the current one has no room left.
+ */
+ void host_write(void const *data, size_t size) override {
+ JNIEnv *env = cudf::jni::get_jni_env(jvm);
+ long left_to_copy = static_cast(size);
+ const char *copy_from = static_cast(data);
+ while (left_to_copy > 0) {
+ long buffer_amount_available = current_buffer_len - current_buffer_written;
+ if (buffer_amount_available <= 0) {
+ // should never be < 0, but just to be safe
+ rotate_buffer(env);
+ buffer_amount_available = current_buffer_len - current_buffer_written;
+ }
+ // Copy as much as fits; the remainder is handled by the next loop iteration.
+ long amount_to_copy =
+ left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available;
+ char *copy_to = current_buffer_data + current_buffer_written;
+
+ std::memcpy(copy_to, copy_from, amount_to_copy);
+ copy_from = copy_from + amount_to_copy;
+ current_buffer_written += amount_to_copy;
+ total_written += amount_to_copy;
+ left_to_copy -= amount_to_copy;
+ }
+ }
+
+ bool supports_device_write() const override { return true; }
+
+ /**
+ * @brief Copies `size` bytes of device memory into the host buffers on `stream`.
+ *
+ * The stream is synchronized before a buffer is handed to the callback and once more
+ * before returning, so every issued copy has completed by then.
+ */
+ void device_write(void const *gpu_data, size_t size, rmm::cuda_stream_view stream) override {
+ JNIEnv *env = cudf::jni::get_jni_env(jvm);
+ long left_to_copy = static_cast(size);
+ const char *copy_from = static_cast(gpu_data);
+ while (left_to_copy > 0) {
+ long buffer_amount_available = current_buffer_len - current_buffer_written;
+ if (buffer_amount_available <= 0) {
+ // should never be < 0, but just to be safe
+ // Pending async copies into the current buffer must finish before it is handed off.
+ stream.synchronize();
+ rotate_buffer(env);
+ buffer_amount_available = current_buffer_len - current_buffer_written;
+ }
+ long amount_to_copy =
+ left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available;
+ char *copy_to = current_buffer_data + current_buffer_written;
+
+ CUDF_CUDA_TRY(cudaMemcpyAsync(copy_to, copy_from, amount_to_copy, cudaMemcpyDeviceToHost,
+ stream.value()));
+
+ copy_from = copy_from + amount_to_copy;
+ current_buffer_written += amount_to_copy;
+ total_written += amount_to_copy;
+ left_to_copy -= amount_to_copy;
+ }
+ stream.synchronize();
+ }
+
+ /**
+ * @brief Performs the write synchronously and returns a deferred future that does no work.
+ */
+ std::future device_write_async(void const *gpu_data, size_t size,
+ rmm::cuda_stream_view stream) override {
+ // Call the sync version until figuring out how to write asynchronously.
+ device_write(gpu_data, size, stream);
+ return std::async(std::launch::deferred, [] {});
+ }
+
+ /**
+ * @brief Hands the partially filled buffer (if it holds any bytes) to the callback and
+ * resets the buffer state, so the next write allocates a new buffer.
+ */
+ void flush() override {
+ if (current_buffer_written > 0) {
+ JNIEnv *env = cudf::jni::get_jni_env(jvm);
+ handle_buffer(env, current_buffer, current_buffer_written);
+ if (current_buffer != nullptr) {
+ env->DeleteGlobalRef(current_buffer);
+ }
+ current_buffer = nullptr;
+ current_buffer_len = 0;
+ current_buffer_data = nullptr;
+ current_buffer_written = 0;
+ }
+ }
+
+ /**
+ * @brief Returns the total number of bytes copied through host_write and device_write.
+ */
+ size_t bytes_written() override { return total_written; }
+
+ /**
+ * @brief Sets the size requested for buffers allocated from now on.
+ */
+ void set_alloc_size(long size) { this->alloc_size = size; }
+
+private:
+ // Hands the current buffer (if any) to the callback, then allocates a new host buffer of
+ // alloc_size and makes it the current one.
+ void rotate_buffer(JNIEnv *env) {
+ if (current_buffer != nullptr) {
+ handle_buffer(env, current_buffer, current_buffer_written);
+ env->DeleteGlobalRef(current_buffer);
+ current_buffer = nullptr;
+ }
+ jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true);
+ // Held as a global reference because the buffer is used across separate write calls.
+ current_buffer = env->NewGlobalRef(tmp_buffer);
+ current_buffer_len = get_host_buffer_length(env, current_buffer);
+ current_buffer_data = reinterpret_cast(get_host_buffer_address(env, current_buffer));
+ current_buffer_written = 0;
+ }
+
+ // Calls callback.handleBuffer(buffer, len) and throws std::runtime_error if a Java
+ // exception is pending afterwards.
+ void handle_buffer(JNIEnv *env, jobject buffer, jlong len) {
+ env->CallVoidMethod(callback, handle_buffer_method, buffer, len);
+ if (env->ExceptionCheck()) {
+ throw std::runtime_error("handleBuffer threw an exception");
+ }
+ }
+
+ JavaVM *jvm; // VM used to obtain a JNIEnv on the writing thread
+ jobject callback; // global reference to the Java callback object
+ jmethodID handle_buffer_method; // callback's handleBuffer(HostMemoryBuffer, long)
+ jobject current_buffer = nullptr; // global reference to the buffer being filled
+ char *current_buffer_data = nullptr; // address of the current buffer's memory
+ long current_buffer_len = 0; // length of the current buffer in bytes
+ long current_buffer_written = 0; // bytes already written into the current buffer
+ size_t total_written = 0; // bytes written over the lifetime of the sink
+ long alloc_size = MINIMUM_WRITE_BUFFER_SIZE; // size requested for each new buffer
+};
+
+} // namespace cudf::jni
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java
index bf951a871e7..83e4cb536f3 100644
--- a/java/src/test/java/ai/rapids/cudf/TableTest.java
+++ b/java/src/test/java/ai/rapids/cudf/TableTest.java
@@ -1,6 +1,6 @@
/*
*
- * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -50,7 +50,6 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;
-import java.util.function.IntFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -575,6 +574,123 @@ void testReadCSV() {
}
}
+  /**
+   * Writes a small table to a temporary CSV file using the given settings, reads the file
+   * back with matching read options, and asserts that the round trip preserves the table.
+   *
+   * @param fieldDelim character separating the fields of a row
+   * @param includeHeader whether a header row is written (and expected when reading back)
+   * @param trueValue string written for, and parsed as, boolean true
+   * @param falseValue string written for, and parsed as, boolean false
+   */
+  private void testWriteCSVToFileImpl(char fieldDelim, boolean includeHeader,
+                                      String trueValue, String falseValue) throws IOException {
+    File outputFile = File.createTempFile("testWriteCSVToFile", ".csv");
+    Schema schema = Schema.builder()
+        .column(DType.INT32, "i")
+        .column(DType.FLOAT64, "f")
+        .column(DType.BOOL8, "b")
+        .column(DType.STRING, "str")
+        .build();
+    // Use the method parameters so that each call in testWriteCSVToFile covers a
+    // different configuration.
+    CSVWriterOptions writeOptions = CSVWriterOptions.builder()
+        .withColumnNames(schema.getColumnNames())
+        .withIncludeHeader(includeHeader)
+        .withFieldDelimiter((byte)fieldDelim)
+        .withRowDelimiter("\n")
+        .withNullValue("\\N")
+        .withTrueValue(trueValue)
+        .withFalseValue(falseValue)
+        .build();
+    try (Table inputTable
+             = new Table.TestBuilder()
+            .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+            .column(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
+            .column(false, true, false, true, false, true, false, true, false, true)
+            .column("All", "the", "leaves", "are", "brown", "and", "the", "sky", "is", "grey")
+            .build()) {
+      inputTable.writeCSVToFile(writeOptions, outputFile.getAbsolutePath());
+
+      // Read back with options that mirror the ones used for writing.
+      CSVOptions readOptions = CSVOptions.builder()
+          .includeColumn("i")
+          .includeColumn("f")
+          .includeColumn("b")
+          .includeColumn("str")
+          .hasHeader(includeHeader)
+          .withDelim(fieldDelim)
+          .withTrueValue(trueValue)
+          .withFalseValue(falseValue)
+          .build();
+      try (Table readTable = Table.readCSV(schema, readOptions, outputFile)) {
+        assertTablesAreEqual(inputTable, readTable);
+      }
+    } finally {
+      outputFile.delete();
+    }
+  }
+
+ // Calls the file round-trip helper with four combinations of field delimiter, header flag
+ // and true/false strings.
+ @Test
+ void testWriteCSVToFile() throws IOException {
+ final boolean INCLUDE_HEADER = true;
+ final boolean NO_HEADER = false;
+ testWriteCSVToFileImpl(',', INCLUDE_HEADER, "true", "false");
+ testWriteCSVToFileImpl(',', NO_HEADER, "TRUE", "FALSE");
+ testWriteCSVToFileImpl('\u0001', INCLUDE_HEADER, "T", "F");
+ testWriteCSVToFileImpl('\u0001', NO_HEADER, "True", "False");
+ }
+
+ /**
+ * Writes the same table three times through a chunked CSV buffer writer configured with the
+ * given settings, reads the accumulated bytes back as CSV, and asserts that the result equals
+ * three concatenated copies of the input table. The input includes nulls in every column.
+ */
+ private void testChunkedCSVWriterImpl(char fieldDelim, boolean includeHeader,
+ String trueValue, String falseValue) throws IOException {
+ Schema schema = Schema.builder()
+ .column(DType.INT32, "i")
+ .column(DType.FLOAT64, "f")
+ .column(DType.BOOL8, "b")
+ .column(DType.STRING, "str")
+ .build();
+ CSVWriterOptions writeOptions = CSVWriterOptions.builder()
+ .withColumnNames(schema.getColumnNames())
+ .withIncludeHeader(includeHeader)
+ .withFieldDelimiter((byte)fieldDelim)
+ .withRowDelimiter("\n")
+ .withNullValue("\\N")
+ .withTrueValue(trueValue)
+ .withFalseValue(falseValue)
+ .build();
+ try (Table inputTable
+ = new Table.TestBuilder()
+ .column(0, 1, 2, 3, 4, 5, 6, 7, 8, null)
+ .column(0.0, 1.0, 2.0, 3.0, 4.0, null, 6.0, 7.0, 8.0, 9.0)
+ .column(false, true, null, true, false, true, null, true, false, true)
+ .column("All", "the", "leaves", "are", "brown", "and", "the", "sky", "is", null)
+ .build();
+ MyBufferConsumer consumer = new MyBufferConsumer()) {
+
+ // Closing the writer (end of this try block) finishes the write before reading back.
+ try (TableWriter writer = Table.getCSVBufferWriter(writeOptions, consumer)) {
+ writer.write(inputTable);
+ writer.write(inputTable);
+ writer.write(inputTable);
+ }
+
+ // Read back.
+ CSVOptions readOptions = CSVOptions.builder()
+ .includeColumn("i")
+ .includeColumn("f")
+ .includeColumn("b")
+ .includeColumn("str")
+ .hasHeader(includeHeader)
+ .withDelim(fieldDelim)
+ .withNullValue("\\N")
+ .withTrueValue(trueValue)
+ .withFalseValue(falseValue)
+ .build();
+ // NOTE(review): presumably consumer.offset is the number of bytes the consumer received;
+ // confirm against MyBufferConsumer.
+ try (Table readTable = Table.readCSV(schema, readOptions, consumer.buffer, 0, consumer.offset);
+ Table expected = Table.concatenate(inputTable, inputTable, inputTable)) {
+ assertTablesAreEqual(expected, readTable);
+ }
+ }
+ }
+
+ // Runs the chunked-writer round trip with four combinations of field delimiter, header flag
+ // and true/false strings.
+ @Test
+ void testChunkedCSVWriter() throws IOException {
+ final boolean INCLUDE_HEADER = true;
+ final boolean NO_HEADER = false;
+ testChunkedCSVWriterImpl(',', NO_HEADER, "true", "false");
+ testChunkedCSVWriterImpl(',', INCLUDE_HEADER, "TRUE", "FALSE");
+ testChunkedCSVWriterImpl('\u0001', NO_HEADER, "T", "F");
+ testChunkedCSVWriterImpl('\u0001', INCLUDE_HEADER, "True", "False");
+ }
+
@Test
void testReadParquet() {
ParquetOptions opts = ParquetOptions.builder()