diff --git a/cpp/include/cudf/io/csv.hpp b/cpp/include/cudf/io/csv.hpp index 1fc4114b94c..92b5447527c 100644 --- a/cpp/include/cudf/io/csv.hpp +++ b/cpp/include/cudf/io/csv.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1332,7 +1332,7 @@ class csv_writer_options { size_type _rows_per_chunk = std::numeric_limits::max(); // character to use for separating lines (default "\n") std::string _line_terminator = "\n"; - // character to use for separating lines (default "\n") + // character to use for separating column values (default ",") char _inter_column_delimiter = ','; // string to use for values != 0 in INT8 types (default 'true') std::string _true_value = std::string{"true"}; @@ -1422,9 +1422,9 @@ class csv_writer_options { [[nodiscard]] std::string get_line_terminator() const { return _line_terminator; } /** - * @brief Returns character used for separating lines. + * @brief Returns character used for separating column values. * - * @return Character used for separating lines + * @return Character used for separating column values. */ [[nodiscard]] char get_inter_column_delimiter() const { return _inter_column_delimiter; } @@ -1479,9 +1479,9 @@ class csv_writer_options { void set_line_terminator(std::string term) { _line_terminator = term; } /** - * @brief Sets character used for separating lines. + * @brief Sets character used for separating column values. * - * @param delim Character to indicate delimiting + * @param delim Character to delimit column values */ void set_inter_column_delimiter(char delim) { _inter_column_delimiter = delim; } @@ -1498,6 +1498,13 @@ class csv_writer_options { * @param val String to represent values == 0 in INT8 types */ void set_false_value(std::string val) { _false_value = val; } + + /** + * @brief (Re)sets the table being written. + * + * @param table Table to be written + */ + void set_table(table_view const& table) { _table = table; } }; /** @@ -1586,9 +1593,9 @@ class csv_writer_options_builder { } /** - * @brief Sets character used for separating lines. + * @brief Sets character used for separating column values. * - * @param delim Character to indicate delimiting + * @param delim Character to delimit column values * @return this for chaining */ csv_writer_options_builder& inter_column_delimiter(char delim) diff --git a/java/src/main/java/ai/rapids/cudf/CSVWriterOptions.java b/java/src/main/java/ai/rapids/cudf/CSVWriterOptions.java new file mode 100644 index 00000000000..410eeab2b18 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/CSVWriterOptions.java @@ -0,0 +1,134 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class CSVWriterOptions { + + private String[] columnNames; + private Boolean includeHeader = false; + private String rowDelimiter = "\n"; + private byte fieldDelimiter = ','; + private String nullValue = ""; + private String falseValue = "false"; + private String trueValue = "true"; + + private CSVWriterOptions(Builder builder) { + this.columnNames = builder.columnNames.toArray(new String[builder.columnNames.size()]); + this.nullValue = builder.nullValue; + this.includeHeader = builder.includeHeader; + this.fieldDelimiter = builder.fieldDelimiter; + this.rowDelimiter = builder.rowDelimiter; + this.falseValue = builder.falseValue; + this.trueValue = builder.trueValue; + } + + public String[] getColumnNames() { + return columnNames; + } + + public Boolean getIncludeHeader() { + return includeHeader; + } + + public String getRowDelimiter() { + return rowDelimiter; + } + + public byte getFieldDelimiter() { + return fieldDelimiter; + } + + public String getNullValue() { + return nullValue; + } + + public String getTrueValue() { + return trueValue; + } + + public String getFalseValue() { + return falseValue; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private List columnNames = Collections.emptyList(); + private Boolean includeHeader = false; + private String rowDelimiter = "\n"; + private byte fieldDelimiter = ','; + private String nullValue = ""; + private String falseValue = "false"; + private String trueValue = "true"; + + public CSVWriterOptions build() { + return new CSVWriterOptions(this); + } + + public Builder withColumnNames(List columnNames) { + this.columnNames = columnNames; + return this; + } + + public Builder withColumnNames(String... columnNames) { + List columnNamesList = new ArrayList<>(); + for (String columnName : columnNames) { + columnNamesList.add(columnName); + } + return withColumnNames(columnNamesList); + } + + public Builder withIncludeHeader(Boolean includeHeader) { + this.includeHeader = includeHeader; + return this; + } + + public Builder withRowDelimiter(String rowDelimiter) { + this.rowDelimiter = rowDelimiter; + return this; + } + + public Builder withFieldDelimiter(byte fieldDelimiter) { + this.fieldDelimiter = fieldDelimiter; + return this; + } + + public Builder withNullValue(String nullValue) { + this.nullValue = nullValue; + return this; + } + + public Builder withTrueValue(String trueValue) { + this.trueValue = trueValue; + return this; + } + + public Builder withFalseValue(String falseValue) { + this.falseValue = falseValue; + return this; + } + } +} diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index b93352fa9ac..36dec194017 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -857,6 +857,82 @@ public static Table readCSV(Schema schema, CSVOptions opts, HostMemoryBuffer buf opts.getFalseValues())); } + private static native void writeCSVToFile(long table, + String[] columnNames, + boolean includeHeader, + String rowDelimiter, + byte fieldDelimiter, + String nullValue, + String trueValue, + String falseValue, + String outputPath) throws CudfException; + + public void writeCSVToFile(CSVWriterOptions options, String outputPath) { + writeCSVToFile(nativeHandle, + options.getColumnNames(), + options.getIncludeHeader(), + options.getRowDelimiter(), + options.getFieldDelimiter(), + options.getNullValue(), + options.getTrueValue(), + options.getFalseValue(), + outputPath); + } + + private static native long startWriteCSVToBuffer(String[] columnNames, + boolean includeHeader, + String rowDelimiter, + byte fieldDelimiter, + String nullValue, + String trueValue, + String falseValue, + HostBufferConsumer buffer) throws CudfException; + + private static native void writeCSVChunkToBuffer(long writerHandle, long tableHandle); + + private static native void endWriteCSVToBuffer(long writerHandle); + + private static class CSVTableWriter implements TableWriter { + private long writerHandle; + private HostBufferConsumer consumer; + + private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) { + this.writerHandle = startWriteCSVToBuffer(options.getColumnNames(), + options.getIncludeHeader(), + options.getRowDelimiter(), + options.getFieldDelimiter(), + options.getNullValue(), + options.getTrueValue(), + options.getFalseValue(), + consumer); + this.consumer = consumer; + } + + @Override + public void write(Table table) { + if (writerHandle == 0) { + throw new IllegalStateException("Writer was already closed"); + } + writeCSVChunkToBuffer(writerHandle, table.nativeHandle); + } + + @Override + public void close() throws CudfException { + if (writerHandle != 0) { + endWriteCSVToBuffer(writerHandle); + writerHandle = 0; + } + if (consumer != null) { + consumer.done(); + consumer = null; + } + } + } + + public static TableWriter getCSVBufferWriter(CSVWriterOptions options, HostBufferConsumer bufferConsumer) { + return new CSVTableWriter(options, bufferConsumer); + } + /** * Read a JSON file using the default JSONOptions. * @param schema the schema of the file. You may use Schema.INFERRED to infer the schema. diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index f417cdc597d..8740669db1f 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include #include @@ -47,165 +46,17 @@ #include #include +#include "csv_chunked_writer.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" #include "jni_compiled_expr.hpp" #include "jni_utils.hpp" +#include "jni_writer_data_sink.hpp" #include "row_conversion.hpp" namespace cudf { namespace jni { -constexpr long MINIMUM_WRITE_BUFFER_SIZE = 10 * 1024 * 1024; // 10 MB - -class jni_writer_data_sink final : public cudf::io::data_sink { -public: - explicit jni_writer_data_sink(JNIEnv *env, jobject callback) { - if (env->GetJavaVM(&jvm) < 0) { - throw std::runtime_error("GetJavaVM failed"); - } - - jclass cls = env->GetObjectClass(callback); - if (cls == nullptr) { - throw cudf::jni::jni_exception("class not found"); - } - - handle_buffer_method = - env->GetMethodID(cls, "handleBuffer", "(Lai/rapids/cudf/HostMemoryBuffer;J)V"); - if (handle_buffer_method == nullptr) { - throw cudf::jni::jni_exception("handleBuffer method"); - } - - this->callback = env->NewGlobalRef(callback); - if (this->callback == nullptr) { - throw cudf::jni::jni_exception("global ref"); - } - } - - virtual ~jni_writer_data_sink() { - // This should normally be called by a JVM thread. If the JVM environment is missing then this - // is likely being triggered by the C++ runtime during shutdown. In that case the JVM may - // already be destroyed and this thread should not try to attach to get an environment. - JNIEnv *env = nullptr; - if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { - env->DeleteGlobalRef(callback); - if (current_buffer != nullptr) { - env->DeleteGlobalRef(current_buffer); - } - } - callback = nullptr; - current_buffer = nullptr; - } - - void host_write(void const *data, size_t size) override { - JNIEnv *env = cudf::jni::get_jni_env(jvm); - long left_to_copy = static_cast(size); - const char *copy_from = static_cast(data); - while (left_to_copy > 0) { - long buffer_amount_available = current_buffer_len - current_buffer_written; - if (buffer_amount_available <= 0) { - // should never be < 0, but just to be safe - rotate_buffer(env); - buffer_amount_available = current_buffer_len - current_buffer_written; - } - long amount_to_copy = - left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available; - char *copy_to = current_buffer_data + current_buffer_written; - - std::memcpy(copy_to, copy_from, amount_to_copy); - copy_from = copy_from + amount_to_copy; - current_buffer_written += amount_to_copy; - total_written += amount_to_copy; - left_to_copy -= amount_to_copy; - } - } - - bool supports_device_write() const override { return true; } - - void device_write(void const *gpu_data, size_t size, rmm::cuda_stream_view stream) override { - JNIEnv *env = cudf::jni::get_jni_env(jvm); - long left_to_copy = static_cast(size); - const char *copy_from = static_cast(gpu_data); - while (left_to_copy > 0) { - long buffer_amount_available = current_buffer_len - current_buffer_written; - if (buffer_amount_available <= 0) { - // should never be < 0, but just to be safe - stream.synchronize(); - rotate_buffer(env); - buffer_amount_available = current_buffer_len - current_buffer_written; - } - long amount_to_copy = - left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available; - char *copy_to = current_buffer_data + current_buffer_written; - - CUDF_CUDA_TRY( - cudaMemcpyAsync(copy_to, copy_from, amount_to_copy, cudaMemcpyDefault, stream.value())); - - copy_from = copy_from + amount_to_copy; - current_buffer_written += amount_to_copy; - total_written += amount_to_copy; - left_to_copy -= amount_to_copy; - } - stream.synchronize(); - } - - std::future device_write_async(void const *gpu_data, size_t size, - rmm::cuda_stream_view stream) override { - // Call the sync version until figuring out how to write asynchronously. - device_write(gpu_data, size, stream); - return std::async(std::launch::deferred, [] {}); - } - - void flush() override { - if (current_buffer_written > 0) { - JNIEnv *env = cudf::jni::get_jni_env(jvm); - handle_buffer(env, current_buffer, current_buffer_written); - if (current_buffer != nullptr) { - env->DeleteGlobalRef(current_buffer); - } - current_buffer = nullptr; - current_buffer_len = 0; - current_buffer_data = nullptr; - current_buffer_written = 0; - } - } - - size_t bytes_written() override { return total_written; } - - void set_alloc_size(long size) { this->alloc_size = size; } - -private: - void rotate_buffer(JNIEnv *env) { - if (current_buffer != nullptr) { - handle_buffer(env, current_buffer, current_buffer_written); - env->DeleteGlobalRef(current_buffer); - current_buffer = nullptr; - } - jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true); - current_buffer = env->NewGlobalRef(tmp_buffer); - current_buffer_len = get_host_buffer_length(env, current_buffer); - current_buffer_data = reinterpret_cast(get_host_buffer_address(env, current_buffer)); - current_buffer_written = 0; - } - - void handle_buffer(JNIEnv *env, jobject buffer, jlong len) { - env->CallVoidMethod(callback, handle_buffer_method, buffer, len); - if (env->ExceptionCheck()) { - throw std::runtime_error("handleBuffer threw an exception"); - } - } - - JavaVM *jvm; - jobject callback; - jmethodID handle_buffer_method; - jobject current_buffer = nullptr; - char *current_buffer_data = nullptr; - long current_buffer_len = 0; - long current_buffer_written = 0; - size_t total_written = 0; - long alloc_size = MINIMUM_WRITE_BUFFER_SIZE; -}; - template class jni_table_writer_handle final { public: explicit jni_table_writer_handle(std::unique_ptr writer) @@ -1349,6 +1200,118 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( CATCH_STD(env, NULL); } +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeCSVToFile( + JNIEnv *env, jclass, jlong j_table_handle, jobjectArray j_column_names, jboolean include_header, + jstring j_row_delimiter, jbyte j_field_delimiter, jstring j_null_value, jstring j_true_value, + jstring j_false_value, jstring j_output_path) { + JNI_NULL_CHECK(env, j_table_handle, "table handle cannot be null.", ); + JNI_NULL_CHECK(env, j_column_names, "column name array cannot be null", ); + JNI_NULL_CHECK(env, j_row_delimiter, "row delimiter cannot be null", ); + JNI_NULL_CHECK(env, j_field_delimiter, "field delimiter cannot be null", ); + JNI_NULL_CHECK(env, j_null_value, "null representation string cannot be itself null", ); + JNI_NULL_CHECK(env, j_true_value, "representation string for `true` cannot be null", ); + JNI_NULL_CHECK(env, j_false_value, "representation string for `false` cannot be null", ); + JNI_NULL_CHECK(env, j_output_path, "output path cannot be null", ); + + try { + cudf::jni::auto_set_device(env); + + auto const native_output_path = cudf::jni::native_jstring{env, j_output_path}; + auto const output_path = native_output_path.get(); + + auto const table = reinterpret_cast(j_table_handle); + auto const n_column_names = cudf::jni::native_jstringArray{env, j_column_names}; + auto const column_names = n_column_names.as_cpp_vector(); + + auto const line_terminator = cudf::jni::native_jstring{env, j_row_delimiter}; + auto const na_rep = cudf::jni::native_jstring{env, j_null_value}; + auto const true_value = cudf::jni::native_jstring{env, j_true_value}; + auto const false_value = cudf::jni::native_jstring{env, j_false_value}; + + auto options = cudf::io::csv_writer_options::builder(cudf::io::sink_info{output_path}, *table) + .names(column_names) + .include_header(static_cast(include_header)) + .line_terminator(line_terminator.get()) + .inter_column_delimiter(j_field_delimiter) + .na_rep(na_rep.get()) + .true_value(true_value.get()) + .false_value(false_value.get()); + + cudf::io::write_csv(options.build()); + } + CATCH_STD(env, ); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_startWriteCSVToBuffer( + JNIEnv *env, jclass, jobjectArray j_column_names, jboolean include_header, + jstring j_row_delimiter, jbyte j_field_delimiter, jstring j_null_value, jstring j_true_value, + jstring j_false_value, jobject j_buffer) { + JNI_NULL_CHECK(env, j_column_names, "column name array cannot be null", 0); + JNI_NULL_CHECK(env, j_row_delimiter, "row delimiter cannot be null", 0); + JNI_NULL_CHECK(env, j_field_delimiter, "field delimiter cannot be null", 0); + JNI_NULL_CHECK(env, j_null_value, "null representation string cannot be itself null", 0); + JNI_NULL_CHECK(env, j_buffer, "output buffer cannot be null", 0); + + try { + cudf::jni::auto_set_device(env); + + auto data_sink = std::make_unique(env, j_buffer); + + auto const n_column_names = cudf::jni::native_jstringArray{env, j_column_names}; + auto const column_names = n_column_names.as_cpp_vector(); + + auto const line_terminator = cudf::jni::native_jstring{env, j_row_delimiter}; + auto const na_rep = cudf::jni::native_jstring{env, j_null_value}; + auto const true_value = cudf::jni::native_jstring{env, j_true_value}; + auto const false_value = cudf::jni::native_jstring{env, j_false_value}; + + auto options = cudf::io::csv_writer_options::builder(cudf::io::sink_info{data_sink.get()}, + cudf::table_view{}) + .names(column_names) + .include_header(static_cast(include_header)) + .line_terminator(line_terminator.get()) + .inter_column_delimiter(j_field_delimiter) + .na_rep(na_rep.get()) + .true_value(true_value.get()) + .false_value(false_value.get()) + .build(); + + return ptr_as_jlong(new cudf::jni::io::csv_chunked_writer{options, data_sink}); + } + CATCH_STD(env, 0); +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeCSVChunkToBuffer(JNIEnv *env, jclass, + jlong j_writer_handle, + jlong j_table_handle) { + JNI_NULL_CHECK(env, j_writer_handle, "writer handle cannot be null.", ); + JNI_NULL_CHECK(env, j_table_handle, "table handle cannot be null.", ); + + auto const table = reinterpret_cast(j_table_handle); + auto writer = reinterpret_cast(j_writer_handle); + + try { + cudf::jni::auto_set_device(env); + writer->write(*table); + } + CATCH_STD(env, ); +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_endWriteCSVToBuffer(JNIEnv *env, jclass, + jlong j_writer_handle) { + JNI_NULL_CHECK(env, j_writer_handle, "writer handle cannot be null.", ); + + using cudf::jni::io::csv_chunked_writer; + auto writer = + std::unique_ptr{reinterpret_cast(j_writer_handle)}; + + try { + cudf::jni::auto_set_device(env); + writer->close(); + } + CATCH_STD(env, ); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readAndInferJSON( JNIEnv *env, jclass, jlong buffer, jlong buffer_length, jboolean day_first, jboolean lines) { diff --git a/java/src/main/native/src/csv_chunked_writer.hpp b/java/src/main/native/src/csv_chunked_writer.hpp new file mode 100644 index 00000000000..1f1e73a1a4b --- /dev/null +++ b/java/src/main/native/src/csv_chunked_writer.hpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include "jni_writer_data_sink.hpp" + +namespace cudf::jni::io { + +/** + * @brief Class to write multiple Tables into the jni_writer_data_sink. + */ +class csv_chunked_writer { + + cudf::io::csv_writer_options _options; + std::unique_ptr _sink; + + bool _first_write_completed = false; ///< Decides if header should be written. + +public: + explicit csv_chunked_writer(cudf::io::csv_writer_options options, + std::unique_ptr &sink) + : _options{options}, _sink{std::move(sink)} { + auto const &sink_info = _options.get_sink(); + // Assert invariants. + CUDF_EXPECTS(sink_info.type() != cudf::io::io_type::FILEPATH, + "Currently, chunked CSV writes to files is not supported."); + + // Note: csv_writer_options ties the sink(s) to the options, and exposes + // no way to modify the sinks afterwards. + // Ideally, the options would have been separate from the tables written, + // and the destination sinks. + // Here, we retain a modifiable reference to the sink, and confirm the + // options point to the same sink. + CUDF_EXPECTS(sink_info.num_sinks() == 1, "csv_chunked_writer should have exactly one sink."); + CUDF_EXPECTS(sink_info.user_sinks()[0] == _sink.get(), "Sink mismatch."); + } + + void write(cudf::table_view const &table) { + if (_first_write_completed) { + _options.enable_include_header(false); // Don't write header after the first write. + } + + _options.set_table(table); + _options.set_rows_per_chunk(table.num_rows()); + + cudf::io::write_csv(_options); + _first_write_completed = true; + } + + void close() { + // Flush pending writes to sink. + _sink->flush(); + } +}; + +} // namespace cudf::jni::io diff --git a/java/src/main/native/src/jni_writer_data_sink.hpp b/java/src/main/native/src/jni_writer_data_sink.hpp new file mode 100644 index 00000000000..05fe594fcd5 --- /dev/null +++ b/java/src/main/native/src/jni_writer_data_sink.hpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cudf_jni_apis.hpp" +#include "jni_utils.hpp" + +namespace cudf::jni { + +constexpr long MINIMUM_WRITE_BUFFER_SIZE = 10 * 1024 * 1024; // 10 MB + +class jni_writer_data_sink final : public cudf::io::data_sink { +public: + explicit jni_writer_data_sink(JNIEnv *env, jobject callback) { + if (env->GetJavaVM(&jvm) < 0) { + throw std::runtime_error("GetJavaVM failed"); + } + + jclass cls = env->GetObjectClass(callback); + if (cls == nullptr) { + throw cudf::jni::jni_exception("class not found"); + } + + handle_buffer_method = + env->GetMethodID(cls, "handleBuffer", "(Lai/rapids/cudf/HostMemoryBuffer;J)V"); + if (handle_buffer_method == nullptr) { + throw cudf::jni::jni_exception("handleBuffer method"); + } + + this->callback = env->NewGlobalRef(callback); + if (this->callback == nullptr) { + throw cudf::jni::jni_exception("global ref"); + } + } + + virtual ~jni_writer_data_sink() { + // This should normally be called by a JVM thread. If the JVM environment is missing then this + // is likely being triggered by the C++ runtime during shutdown. In that case the JVM may + // already be destroyed and this thread should not try to attach to get an environment. + JNIEnv *env = nullptr; + if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { + env->DeleteGlobalRef(callback); + if (current_buffer != nullptr) { + env->DeleteGlobalRef(current_buffer); + } + } + callback = nullptr; + current_buffer = nullptr; + } + + void host_write(void const *data, size_t size) override { + JNIEnv *env = cudf::jni::get_jni_env(jvm); + long left_to_copy = static_cast(size); + const char *copy_from = static_cast(data); + while (left_to_copy > 0) { + long buffer_amount_available = current_buffer_len - current_buffer_written; + if (buffer_amount_available <= 0) { + // should never be < 0, but just to be safe + rotate_buffer(env); + buffer_amount_available = current_buffer_len - current_buffer_written; + } + long amount_to_copy = + left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available; + char *copy_to = current_buffer_data + current_buffer_written; + + std::memcpy(copy_to, copy_from, amount_to_copy); + copy_from = copy_from + amount_to_copy; + current_buffer_written += amount_to_copy; + total_written += amount_to_copy; + left_to_copy -= amount_to_copy; + } + } + + bool supports_device_write() const override { return true; } + + void device_write(void const *gpu_data, size_t size, rmm::cuda_stream_view stream) override { + JNIEnv *env = cudf::jni::get_jni_env(jvm); + long left_to_copy = static_cast(size); + const char *copy_from = static_cast(gpu_data); + while (left_to_copy > 0) { + long buffer_amount_available = current_buffer_len - current_buffer_written; + if (buffer_amount_available <= 0) { + // should never be < 0, but just to be safe + stream.synchronize(); + rotate_buffer(env); + buffer_amount_available = current_buffer_len - current_buffer_written; + } + long amount_to_copy = + left_to_copy < buffer_amount_available ? left_to_copy : buffer_amount_available; + char *copy_to = current_buffer_data + current_buffer_written; + + CUDF_CUDA_TRY(cudaMemcpyAsync(copy_to, copy_from, amount_to_copy, cudaMemcpyDeviceToHost, + stream.value())); + + copy_from = copy_from + amount_to_copy; + current_buffer_written += amount_to_copy; + total_written += amount_to_copy; + left_to_copy -= amount_to_copy; + } + stream.synchronize(); + } + + std::future device_write_async(void const *gpu_data, size_t size, + rmm::cuda_stream_view stream) override { + // Call the sync version until figuring out how to write asynchronously. + device_write(gpu_data, size, stream); + return std::async(std::launch::deferred, [] {}); + } + + void flush() override { + if (current_buffer_written > 0) { + JNIEnv *env = cudf::jni::get_jni_env(jvm); + handle_buffer(env, current_buffer, current_buffer_written); + if (current_buffer != nullptr) { + env->DeleteGlobalRef(current_buffer); + } + current_buffer = nullptr; + current_buffer_len = 0; + current_buffer_data = nullptr; + current_buffer_written = 0; + } + } + + size_t bytes_written() override { return total_written; } + + void set_alloc_size(long size) { this->alloc_size = size; } + +private: + void rotate_buffer(JNIEnv *env) { + if (current_buffer != nullptr) { + handle_buffer(env, current_buffer, current_buffer_written); + env->DeleteGlobalRef(current_buffer); + current_buffer = nullptr; + } + jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true); + current_buffer = env->NewGlobalRef(tmp_buffer); + current_buffer_len = get_host_buffer_length(env, current_buffer); + current_buffer_data = reinterpret_cast(get_host_buffer_address(env, current_buffer)); + current_buffer_written = 0; + } + + void handle_buffer(JNIEnv *env, jobject buffer, jlong len) { + env->CallVoidMethod(callback, handle_buffer_method, buffer, len); + if (env->ExceptionCheck()) { + throw std::runtime_error("handleBuffer threw an exception"); + } + } + + JavaVM *jvm; + jobject callback; + jmethodID handle_buffer_method; + jobject current_buffer = nullptr; + char *current_buffer_data = nullptr; + long current_buffer_len = 0; + long current_buffer_written = 0; + size_t total_written = 0; + long alloc_size = MINIMUM_WRITE_BUFFER_SIZE; +}; + +} // namespace cudf::jni diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index bf951a871e7..83e4cb536f3 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,6 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.*; -import java.util.function.IntFunction; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -575,6 +574,123 @@ void testReadCSV() { } } + private void testWriteCSVToFileImpl(char fieldDelim, boolean includeHeader, + String trueValue, String falseValue) throws IOException { + File outputFile = File.createTempFile("testWriteCSVToFile", ".csv"); + Schema schema = Schema.builder() + .column(DType.INT32, "i") + .column(DType.FLOAT64, "f") + .column(DType.BOOL8, "b") + .column(DType.STRING, "str") + .build(); + CSVWriterOptions writeOptions = CSVWriterOptions.builder() + .withColumnNames(schema.getColumnNames()) + .withIncludeHeader(false) + .withFieldDelimiter((byte)'\u0001') + .withRowDelimiter("\n") + .withNullValue("\\N") + .withTrueValue("T") + .withFalseValue("F") + .build(); + try (Table inputTable + = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0) + .column(false, true, false, true, false, true, false, true, false, true) + .column("All", "the", "leaves", "are", "brown", "and", "the", "sky", "is", "grey") + .build()) { + inputTable.writeCSVToFile(writeOptions, outputFile.getAbsolutePath()); + + // Read back. + CSVOptions readOptions = CSVOptions.builder() + .includeColumn("i") + .includeColumn("f") + .includeColumn("b") + .includeColumn("str") + .hasHeader(false) + .withDelim('\u0001') + .withTrueValue("T") + .withFalseValue("F") + .build(); + try (Table readTable = Table.readCSV(schema, readOptions, outputFile)) { + assertTablesAreEqual(inputTable, readTable); + } + } finally { + outputFile.delete(); + } + } + + @Test + void testWriteCSVToFile() throws IOException { + final boolean INCLUDE_HEADER = true; + final boolean NO_HEADER = false; + testWriteCSVToFileImpl(',', INCLUDE_HEADER, "true", "false"); + testWriteCSVToFileImpl(',', NO_HEADER, "TRUE", "FALSE"); + testWriteCSVToFileImpl('\u0001', INCLUDE_HEADER, "T", "F"); + testWriteCSVToFileImpl('\u0001', NO_HEADER, "True", "False"); + } + + private void testChunkedCSVWriterImpl(char fieldDelim, boolean includeHeader, + String trueValue, String falseValue) throws IOException { + Schema schema = Schema.builder() + .column(DType.INT32, "i") + .column(DType.FLOAT64, "f") + .column(DType.BOOL8, "b") + .column(DType.STRING, "str") + .build(); + CSVWriterOptions writeOptions = CSVWriterOptions.builder() + .withColumnNames(schema.getColumnNames()) + .withIncludeHeader(includeHeader) + .withFieldDelimiter((byte)fieldDelim) + .withRowDelimiter("\n") + .withNullValue("\\N") + .withTrueValue(trueValue) + .withFalseValue(falseValue) + .build(); + try (Table inputTable + = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 8, null) + .column(0.0, 1.0, 2.0, 3.0, 4.0, null, 6.0, 7.0, 8.0, 9.0) + .column(false, true, null, true, false, true, null, true, false, true) + .column("All", "the", "leaves", "are", "brown", "and", "the", "sky", "is", null) + .build(); + MyBufferConsumer consumer = new MyBufferConsumer()) { + + try (TableWriter writer = Table.getCSVBufferWriter(writeOptions, consumer)) { + writer.write(inputTable); + writer.write(inputTable); + writer.write(inputTable); + } + + // Read back. + CSVOptions readOptions = CSVOptions.builder() + .includeColumn("i") + .includeColumn("f") + .includeColumn("b") + .includeColumn("str") + .hasHeader(includeHeader) + .withDelim(fieldDelim) + .withNullValue("\\N") + .withTrueValue(trueValue) + .withFalseValue(falseValue) + .build(); + try (Table readTable = Table.readCSV(schema, readOptions, consumer.buffer, 0, consumer.offset); + Table expected = Table.concatenate(inputTable, inputTable, inputTable)) { + assertTablesAreEqual(expected, readTable); + } + } + } + + @Test + void testChunkedCSVWriter() throws IOException { + final boolean INCLUDE_HEADER = true; + final boolean NO_HEADER = false; + testChunkedCSVWriterImpl(',', NO_HEADER, "true", "false"); + testChunkedCSVWriterImpl(',', INCLUDE_HEADER, "TRUE", "FALSE"); + testChunkedCSVWriterImpl('\u0001', NO_HEADER, "T", "F"); + testChunkedCSVWriterImpl('\u0001', INCLUDE_HEADER, "True", "False"); + } + @Test void testReadParquet() { ParquetOptions opts = ParquetOptions.builder()